Exemple #1
0
def execute():  #pylint: disable=unused-variable
    """Iteratively estimate a single-fibre response function.

    Each iteration runs CSD with the current response function, extracts the
    two largest FOD peaks per voxel, keeps the voxels whose second-to-first
    peak amplitude ratio is below ``app.ARGS.peak_ratio`` and re-estimates the
    response function from that selection. Iteration stops once no response
    function coefficient changes by more than ``app.ARGS.convergence`` percent,
    or after ``app.ARGS.max_iters`` iterations.

    Reads ``dwi.mif`` and ``mask.mif`` from the current working directory;
    writes ``response.txt`` and ``voxels.mif`` there, then copies the response
    to ``app.ARGS.output`` and (if requested) the voxel selection to
    ``app.ARGS.voxels``.

    Raises:
        MRtrixError: if fewer than one iteration is requested, or if an
            iteration leaves no single-fibre voxels.
    """
    # Without at least one iteration there is no response function or voxel
    #   count to report; fail with a clear message rather than a NameError later
    if app.ARGS.max_iters < 1:
        raise MRtrixError('Number of iterations must be at least 1')

    lmax_option = ''
    if app.ARGS.lmax:
        lmax_option = ' -lmax ' + app.ARGS.lmax

    # Convergence threshold is supplied as a percentage; convert to a fraction
    convergence_change = 0.01 * app.ARGS.convergence

    progress = app.ProgressBar('Optimising')

    iteration = 0
    while iteration < app.ARGS.max_iters:
        prefix = 'iter' + str(iteration) + '_'

        # How to initialise response function?
        # old dwi2response command used mean & standard deviation of DWI data; however
        #   this may force the output FODs to lmax=2 at the first iteration
        # Chantal used a tensor with low FA, but it'd be preferable to get the scaling right
        # Other option is to do as before, but get the ratio between l=0 and l=2, and
        #   generate l=4,6,... using that amplitude ratio
        if iteration == 0:
            rf_in_path = 'init_RF.txt'
            mask_in_path = 'mask.mif'

            # Grab the mean and standard deviation across all volumes in a single mrstats call
            # Also scale them to reflect the fact that we're moving to the SH basis
            image_stats = image.statistics('dwi.mif',
                                           mask='mask.mif',
                                           allvolumes=True)
            mean = image_stats.mean * math.sqrt(4.0 * math.pi)
            std = image_stats.std * math.sqrt(4.0 * math.pi)

            # Now produce the initial response function
            # Let's only do it to lmax 4
            init_rf = [
                str(mean),
                str(-0.5 * std),
                str(0.25 * std * std / mean)
            ]
            with open('init_RF.txt', 'w') as init_rf_file:
                init_rf_file.write(' '.join(init_rf))
        else:
            # Subsequent iterations start from the previous iteration's outputs
            rf_in_path = 'iter' + str(iteration - 1) + '_RF.txt'
            mask_in_path = 'iter' + str(iteration - 1) + '_SF.mif'

        # Run CSD
        run.command('dwi2fod csd dwi.mif ' + rf_in_path + ' ' + prefix +
                    'FOD.mif -mask ' + mask_in_path)
        # Get amplitudes of two largest peaks, and directions of largest
        run.command('fod2fixel ' + prefix + 'FOD.mif ' + prefix +
                    'fixel -peak peaks.mif -mask ' + mask_in_path +
                    ' -fmls_no_thresholds')
        app.cleanup(prefix + 'FOD.mif')
        run.command('fixel2voxel ' + prefix + 'fixel/peaks.mif none ' +
                    prefix + 'amps.mif')
        run.command('mrconvert ' + prefix + 'amps.mif ' + prefix +
                    'first_peaks.mif -coord 3 0 -axes 0,1,2')
        run.command('mrconvert ' + prefix + 'amps.mif ' + prefix +
                    'second_peaks.mif -coord 3 1 -axes 0,1,2')
        app.cleanup(prefix + 'amps.mif')
        run.command('fixel2peaks ' + prefix + 'fixel/directions.mif ' +
                    prefix + 'first_dir.mif -number 1')
        app.cleanup(prefix + 'fixel')
        # Revise single-fibre voxel selection based on ratio of tallest to second-tallest peak
        run.command('mrcalc ' + prefix + 'second_peaks.mif ' + prefix +
                    'first_peaks.mif -div ' + prefix + 'peak_ratio.mif')
        app.cleanup(prefix + 'first_peaks.mif')
        app.cleanup(prefix + 'second_peaks.mif')
        run.command('mrcalc ' + prefix + 'peak_ratio.mif ' +
                    str(app.ARGS.peak_ratio) + ' -lt ' + mask_in_path +
                    ' -mult ' + prefix + 'SF.mif -datatype bit')
        app.cleanup(prefix + 'peak_ratio.mif')
        # Make sure image isn't empty
        sf_voxel_count = image.statistics(prefix + 'SF.mif',
                                          mask=prefix + 'SF.mif').count
        if not sf_voxel_count:
            raise MRtrixError(
                'Aborting: All voxels have been excluded from single-fibre selection'
            )
        # Generate a new response function
        run.command('amp2response dwi.mif ' + prefix + 'SF.mif ' + prefix +
                    'first_dir.mif ' + prefix + 'RF.txt' + lmax_option)
        app.cleanup(prefix + 'first_dir.mif')

        new_rf = matrix.load_vector(prefix + 'RF.txt')
        progress.increment('Optimising (' + str(iteration + 1) +
                           ' iterations, ' + str(sf_voxel_count) +
                           ' voxels, RF: [ ' + ', '.join('{:.3f}'.format(n)
                                                         for n in new_rf) +
                           '] )')

        # Detect convergence
        # Look for a change > some percentage - don't bother looking at the masks
        if iteration > 0:
            old_rf = matrix.load_vector(rf_in_path)
            # Relative change = |half-difference| / |pairwise mean|.
            # The magnitude of the mean must be used: response coefficients can be
            #   negative, and a signed mean would make the ratio negative so that
            #   such coefficients could never trigger another iteration.
            # Comparing against the scaled mean (rather than dividing) also avoids
            #   a ZeroDivisionError when a coefficient pair averages to zero.
            reiterate = any(
                math.fabs(0.5 * (old_value - new_value)) >
                convergence_change * math.fabs(0.5 * (old_value + new_value))
                for old_value, new_value in zip(old_rf, new_rf))
            if not reiterate:
                run.function(shutil.copyfile, prefix + 'RF.txt',
                             'response.txt')
                run.function(shutil.copyfile, prefix + 'SF.mif', 'voxels.mif')
                break

        # Inputs to this iteration are no longer needed
        app.cleanup(rf_in_path)
        app.cleanup(mask_in_path)

        iteration += 1

    progress.done()

    # If we've terminated due to hitting the iteration limiter, we still need to copy the output file(s) to the correct location
    if os.path.exists('response.txt'):
        app.console('Exited at iteration ' + str(iteration + 1) + ' with ' +
                    str(sf_voxel_count) +
                    ' SF voxels due to unchanged RF coefficients')
    else:
        app.console('Exited after maximum ' + str(app.ARGS.max_iters) +
                    ' iterations with ' + str(sf_voxel_count) + ' SF voxels')
        run.function(shutil.copyfile,
                     'iter' + str(app.ARGS.max_iters - 1) + '_RF.txt',
                     'response.txt')
        run.function(shutil.copyfile,
                     'iter' + str(app.ARGS.max_iters - 1) + '_SF.mif',
                     'voxels.mif')

    run.function(shutil.copyfile, 'response.txt',
                 path.from_user(app.ARGS.output, False))
    if app.ARGS.voxels:
        run.command('mrconvert voxels.mif ' + path.from_user(app.ARGS.voxels),
                    mrconvert_keyval=path.from_user(app.ARGS.input, False),
                    force=app.FORCE_OVERWRITE)
Exemple #2
0
def execute():  #pylint: disable=unused-variable
    """Iteratively select single-fibre voxels and estimate their response function.

    Every iteration runs CSD within the current mask, ranks voxels with a
    cost function built from the two largest FOD peak amplitudes, keeps the
    top ``app.ARGS.number`` voxels and re-estimates the response function from
    them. The loop ends when two successive voxel selections are identical or
    after ``app.ARGS.max_iters`` iterations.

    Reads ``dwi.mif`` and ``mask.mif`` from the current working directory;
    writes ``response.txt`` and ``voxels.mif`` there, then copies the response
    to ``app.ARGS.output`` and (if requested) the voxel selection to
    ``app.ARGS.voxels``.

    Raises:
        MRtrixError: if fewer than two iterations are requested, or if
            ``-iter_voxels`` is non-zero but smaller than ``-number``.
    """
    lmax_option = (' -lmax ' + app.ARGS.lmax) if app.ARGS.lmax else ''

    if app.ARGS.max_iters < 2:
        raise MRtrixError('Number of iterations must be at least 2')

    progress = app.ProgressBar('Optimising')

    # Number of top-ranked voxels carried over for re-testing; zero requests
    #   the default of ten times the desired final count
    retest_count = app.ARGS.iter_voxels
    if retest_count == 0:
        retest_count = 10 * app.ARGS.number
    elif retest_count < app.ARGS.number:
        raise MRtrixError(
            'Number of selected voxels (-iter_voxels) must be greater than number of voxels desired (-number)'
        )

    iteration = 0
    while iteration < app.ARGS.max_iters:
        tag = 'iter' + str(iteration) + '_'

        if iteration:
            # Continue from the previous iteration's response and dilated selection
            previous_tag = 'iter' + str(iteration - 1) + '_'
            rf_in_path = previous_tag + 'RF.txt'
            mask_in_path = previous_tag + 'SF_dilated.mif'
            iter_lmax_option = lmax_option
        else:
            # First pass: fixed initial response function over the whole mask
            rf_in_path = 'init_RF.txt'
            mask_in_path = 'mask.mif'
            with open(rf_in_path, 'w') as init_rf_file:
                init_rf_file.write('1 -1 1')
            iter_lmax_option = ' -lmax 4'

        # Intermediate files produced during this iteration
        fod_path = tag + 'FOD.mif'
        fixel_dir = tag + 'fixel'
        amps_path = tag + 'amps.mif'
        first_peaks_path = tag + 'first_peaks.mif'
        second_peaks_path = tag + 'second_peaks.mif'
        first_dir_path = tag + 'first_dir.mif'
        cost_path = tag + 'CF.mif'
        sf_path = tag + 'SF.mif'
        rf_path = tag + 'RF.txt'

        # CSD using the current response function
        run.command('dwi2fod csd dwi.mif ' + rf_in_path + ' ' + fod_path +
                    ' -mask ' + mask_in_path)
        # Segment the FODs to obtain peak amplitudes and directions
        run.command('fod2fixel ' + fod_path + ' ' + fixel_dir +
                    ' -peak peaks.mif -mask ' + mask_in_path +
                    ' -fmls_no_thresholds')
        app.cleanup(fod_path)
        if iteration:
            app.cleanup(mask_in_path)
        run.command('fixel2voxel ' + fixel_dir + '/peaks.mif none ' +
                    amps_path + ' -number 2')
        # Split the two largest peak amplitudes into separate 3D images
        for peak_index, peaks_path in enumerate(
                (first_peaks_path, second_peaks_path)):
            run.command('mrconvert ' + amps_path + ' ' + peaks_path +
                        ' -coord 3 ' + str(peak_index) + ' -axes 0,1,2')
        app.cleanup(amps_path)
        run.command('fixel2peaks ' + fixel_dir + '/directions.mif ' +
                    first_dir_path + ' -number 1')
        app.cleanup(fixel_dir)

        # Cost function for single-fibre voxel selection
        # https://github.com/MRtrix3/mrtrix3/pull/426
        #  sqrt(|peak1|) * (1 - |peak2| / |peak1|)^2
        run.command('mrcalc ' + first_peaks_path + ' -sqrt 1 ' +
                    second_peaks_path + ' ' + first_peaks_path +
                    ' -div -sub 2 -pow -mult ' + cost_path)
        app.cleanup(first_peaks_path)
        app.cleanup(second_peaks_path)
        voxel_count = image.statistics(cost_path).count

        # Keep the best-ranked voxels, never asking for more than are available
        run.command('mrthreshold ' + cost_path + ' -top ' +
                    str(min(app.ARGS.number, voxel_count)) + ' ' + sf_path)
        # Re-estimate the response function from that selection
        run.command('amp2response dwi.mif ' + sf_path + ' ' + first_dir_path +
                    ' ' + rf_path + iter_lmax_option)
        app.cleanup(first_dir_path)

        new_rf = matrix.load_vector(rf_path)
        rf_text = ', '.join('{:.3f}'.format(n) for n in new_rf)
        progress.increment('Optimising (' + str(iteration + 1) +
                           ' iterations, RF: [ ' + rf_text + '] )')

        # Terminate once the voxel selection matches that of the previous iteration
        if iteration > 0:
            previous_sf_path = 'iter' + str(iteration - 1) + '_SF.mif'
            sf_diff_path = tag + 'SF_diff.mif'
            run.command('mrcalc ' + sf_path + ' ' + previous_sf_path +
                        ' -sub ' + sf_diff_path)
            app.cleanup(previous_sf_path)
            max_diff = image.statistics(sf_diff_path).max
            app.cleanup(sf_diff_path)
            if not max_diff:
                app.cleanup(cost_path)
                run.function(shutil.copyfile, rf_path, 'response.txt')
                run.function(shutil.move, sf_path, 'voxels.mif')
                break

        # Take a larger set of top-ranked voxels and dilate it (restricted to the
        #   initial mask); this is the region re-tested in the next iteration
        run.command('mrthreshold ' + cost_path + ' -top ' +
                    str(min(retest_count, voxel_count)) +
                    ' - | maskfilter - dilate - -npass ' +
                    str(app.ARGS.dilate) + ' | mrcalc mask.mif - -mult ' +
                    tag + 'SF_dilated.mif')
        app.cleanup(cost_path)

        iteration += 1

    progress.done()

    # Running out of iterations leaves no response.txt; take the final iteration's results
    if os.path.exists('response.txt'):
        app.console(
            'Convergence of SF voxel selection detected at iteration ' +
            str(iteration + 1))
    else:
        app.console('Exiting after maximum ' + str(app.ARGS.max_iters) +
                    ' iterations')
        final_tag = 'iter' + str(app.ARGS.max_iters - 1) + '_'
        run.function(shutil.copyfile, final_tag + 'RF.txt', 'response.txt')
        run.function(shutil.move, final_tag + 'SF.mif', 'voxels.mif')

    run.function(shutil.copyfile, 'response.txt',
                 path.from_user(app.ARGS.output, False))
    if app.ARGS.voxels:
        run.command('mrconvert voxels.mif ' + path.from_user(app.ARGS.voxels),
                    mrconvert_keyval=path.from_user(app.ARGS.input, False),
                    force=app.FORCE_OVERWRITE)
Exemple #3
0
def execute():  #pylint: disable=unused-variable
    """Estimate single-fibre WM, GM and CSF response functions from DWI data.

    Works on ``dwi.mif`` and ``mask.mif`` in the current working directory and
    proceeds through the stages announced on the console: input checks,
    preparation (mask erosion and signal decay metric), crude segmentation,
    refined segmentation, final voxel selection with response estimation per
    tissue, and output generation.

    The three responses are copied to ``app.ARGS.out_sfwm``,
    ``app.ARGS.out_gm`` and ``app.ARGS.out_csf``; if ``app.ARGS.voxels`` is
    set, the 4D image of final voxel selections is written there as well.

    Raises:
        MRtrixError: if fewer than two unique b-values are detected, or if the
            values given to ``-lmax`` do not match the number of b-values, are
            odd, or are negative.
    """
    # Threshold below which a b-value is treated as b=0 further down (used to
    #   decide which shells get an isotropic initial WM response); taken from
    #   the 'BZeroThreshold' config entry when present, otherwise 10.0
    bzero_threshold = float(
        CONFIG['BZeroThreshold']) if 'BZeroThreshold' in CONFIG else 10.0

    # CHECK INPUTS AND OPTIONS
    app.console('-------')

    # Get b-values and number of volumes per b-value.
    bvalues = [
        int(round(float(x)))
        for x in image.mrinfo('dwi.mif', 'shell_bvalues').split()
    ]
    bvolumes = [int(x) for x in image.mrinfo('dwi.mif', 'shell_sizes').split()]
    app.console(
        str(len(bvalues)) + ' unique b-value(s) detected: ' +
        ','.join(map(str, bvalues)) + ' with ' + ','.join(map(str, bvolumes)) +
        ' volumes')
    if len(bvalues) < 2:
        raise MRtrixError('Need at least 2 unique b-values (including b=0).')
    bvalues_option = ' -shells ' + ','.join(map(str, bvalues))

    # Get lmax information (if provided).
    sfwm_lmax = []
    if app.ARGS.lmax:
        sfwm_lmax = [int(x.strip()) for x in app.ARGS.lmax.split(',')]
        if not len(sfwm_lmax) == len(bvalues):
            raise MRtrixError('Number of lmax\'s (' + str(len(sfwm_lmax)) +
                              ', as supplied to the -lmax option: ' +
                              ','.join(map(str, sfwm_lmax)) +
                              ') does not match number of unique b-values.')
        for sfl in sfwm_lmax:
            if sfl % 2:
                raise MRtrixError(
                    'Values supplied to the -lmax option must be even.')
            if sfl < 0:
                raise MRtrixError(
                    'Values supplied to the -lmax option must be non-negative.'
                )
    sfwm_lmax_option = ''
    if sfwm_lmax:
        sfwm_lmax_option = ' -lmax ' + ','.join(map(str, sfwm_lmax))

    # PREPARATION
    app.console('-------')
    app.console('Preparation:')

    # Erode (brain) mask.
    if app.ARGS.erode > 0:
        app.console('* Eroding brain mask by ' + str(app.ARGS.erode) +
                    ' pass(es)...')
        run.command('maskfilter mask.mif erode eroded_mask.mif -npass ' +
                    str(app.ARGS.erode),
                    show=False)
    else:
        app.console('Not eroding brain mask.')
        # Still produce eroded_mask.mif so later steps can use the same file name
        run.command('mrconvert mask.mif eroded_mask.mif -datatype bit',
                    show=False)
    statmaskcount = image.statistics('mask.mif', mask='mask.mif').count
    statemaskcount = image.statistics('eroded_mask.mif',
                                      mask='eroded_mask.mif').count
    app.console('  [ mask: ' + str(statmaskcount) + ' -> ' +
                str(statemaskcount) + ' ]')

    # Get volumes, compute mean signal and SDM per b-value; compute overall SDM; get rid of erroneous values.
    app.console('* Computing signal decay metric (SDM):')
    totvolumes = 0
    # Both mrcalc command lines are assembled incrementally inside the loop:
    #   fullsdmcmd accumulates the volume-weighted sum of per-shell SDM images,
    #   errcmd accumulates the sum of per-shell error masks
    fullsdmcmd = 'mrcalc'
    errcmd = 'mrcalc'
    # Mean image of the first listed shell; reference for every per-shell SDM
    zeropath = 'mean_b' + str(bvalues[0]) + '.mif'
    for ibv, bval in enumerate(bvalues):
        app.console(' * b=' + str(bval) + '...')
        meanpath = 'mean_b' + str(bval) + '.mif'
        # Clamp negative intensities to zero before averaging across volumes
        run.command('dwiextract dwi.mif -shells ' + str(bval) +
                    ' - | mrcalc - 0 -max - | mrmath - mean ' + meanpath +
                    ' -axis 3',
                    show=False)
        errpath = 'err_b' + str(bval) + '.mif'
        # Flag voxels whose mean signal is non-finite or not strictly positive
        run.command('mrcalc ' + meanpath + ' -finite ' + meanpath +
                    ' 0 -if 0 -le ' + errpath + ' -datatype bit',
                    show=False)
        errcmd += ' ' + errpath
        if ibv > 0:
            errcmd += ' -add'
            sdmpath = 'sdm_b' + str(bval) + '.mif'
            # Per-shell SDM: log of (first shell mean / this shell mean)
            run.command('mrcalc ' + zeropath + ' ' + meanpath +
                        ' -divide -log ' + sdmpath,
                        show=False)
            totvolumes += bvolumes[ibv]
            fullsdmcmd += ' ' + sdmpath + ' ' + str(bvolumes[ibv]) + ' -mult'
            # The first weighted term has nothing to be added to yet
            if ibv > 1:
                fullsdmcmd += ' -add'
    fullsdmcmd += ' ' + str(totvolumes) + ' -divide full_sdm.mif'
    run.command(fullsdmcmd, show=False)
    app.console('* Removing erroneous voxels from mask and correcting SDM...')
    run.command(
        'mrcalc full_sdm.mif -finite full_sdm.mif 0 -if 0 -le err_sdm.mif -datatype bit',
        show=False)
    # Any voxel flagged in at least one error mask is removed from the eroded mask
    errcmd += ' err_sdm.mif -add 0 eroded_mask.mif -if safe_mask.mif -datatype bit'
    run.command(errcmd, show=False)
    # SDM restricted to the safe mask (zero elsewhere) and capped at 10
    run.command('mrcalc safe_mask.mif full_sdm.mif 0 -if 10 -min safe_sdm.mif',
                show=False)
    statsmaskcount = image.statistics('safe_mask.mif',
                                      mask='safe_mask.mif').count
    app.console('  [ mask: ' + str(statemaskcount) + ' -> ' +
                str(statsmaskcount) + ' ]')

    # CRUDE SEGMENTATION
    app.console('-------')
    app.console('Crude segmentation:')

    # Compute FA and principal eigenvectors; crude WM versus GM-CSF separation based on FA.
    app.console('* Crude WM versus GM-CSF separation (at FA=' +
                str(app.ARGS.fa) + ')...')
    run.command(
        'dwi2tensor dwi.mif - -mask safe_mask.mif | tensor2metric - -fa safe_fa.mif -vector safe_vecs.mif -modulate none -mask safe_mask.mif',
        show=False)
    run.command('mrcalc safe_mask.mif safe_fa.mif 0 -if ' + str(app.ARGS.fa) +
                ' -gt crude_wm.mif -datatype bit',
                show=False)
    # Everything in the safe mask that is not crude WM
    run.command(
        'mrcalc crude_wm.mif 0 safe_mask.mif -if _crudenonwm.mif -datatype bit',
        show=False)
    statcrudewmcount = image.statistics('crude_wm.mif',
                                        mask='crude_wm.mif').count
    statcrudenonwmcount = image.statistics('_crudenonwm.mif',
                                           mask='_crudenonwm.mif').count
    app.console('  [ ' + str(statsmaskcount) + ' -> ' + str(statcrudewmcount) +
                ' (WM) & ' + str(statcrudenonwmcount) + ' (GM-CSF) ]')

    # Crude GM versus CSF separation based on SDM.
    app.console('* Crude GM versus CSF separation...')
    crudenonwmmedian = image.statistics('safe_sdm.mif',
                                        mask='_crudenonwm.mif').median
    run.command(
        'mrcalc _crudenonwm.mif safe_sdm.mif ' + str(crudenonwmmedian) +
        ' -subtract 0 -if - | mrthreshold - - -mask _crudenonwm.mif | mrcalc _crudenonwm.mif - 0 -if crude_csf.mif -datatype bit',
        show=False)
    # Crude GM is whatever remains of the non-WM voxels after removing crude CSF
    run.command(
        'mrcalc crude_csf.mif 0 _crudenonwm.mif -if crude_gm.mif -datatype bit',
        show=False)
    statcrudegmcount = image.statistics('crude_gm.mif',
                                        mask='crude_gm.mif').count
    statcrudecsfcount = image.statistics('crude_csf.mif',
                                         mask='crude_csf.mif').count
    app.console('  [ ' + str(statcrudenonwmcount) + ' -> ' +
                str(statcrudegmcount) + ' (GM) & ' + str(statcrudecsfcount) +
                ' (CSF) ]')

    # REFINED SEGMENTATION
    app.console('-------')
    app.console('Refined segmentation:')

    # Refine WM: remove high SDM outliers.
    app.console('* Refining WM...')
    crudewmmedian = image.statistics('safe_sdm.mif',
                                     mask='crude_wm.mif').median
    run.command('mrcalc crude_wm.mif safe_sdm.mif ' + str(crudewmmedian) +
                ' -subtract -abs 0 -if _crudewm_sdmad.mif',
                show=False)
    # Median absolute deviation of the SDM within crude WM
    crudewmmad = image.statistics('_crudewm_sdmad.mif',
                                  mask='crude_wm.mif').median
    # Outlier threshold: median plus two MADs scaled by 1.4826
    crudewmoutlthresh = crudewmmedian + (1.4826 * crudewmmad * 2.0)
    run.command('mrcalc crude_wm.mif safe_sdm.mif 0 -if ' +
                str(crudewmoutlthresh) +
                ' -gt _crudewmoutliers.mif -datatype bit',
                show=False)
    run.command(
        'mrcalc _crudewmoutliers.mif 0 crude_wm.mif -if refined_wm.mif -datatype bit',
        show=False)
    statrefwmcount = image.statistics('refined_wm.mif',
                                      mask='refined_wm.mif').count
    app.console('  [ WM: ' + str(statcrudewmcount) + ' -> ' +
                str(statrefwmcount) + ' ]')

    # Refine GM: separate safer GM from partial volumed voxels.
    app.console('* Refining GM...')
    crudegmmedian = image.statistics('safe_sdm.mif',
                                     mask='crude_gm.mif').median
    # Split crude GM into voxels above and not above the median SDM, threshold
    #   each half separately, then merge the two selections
    run.command('mrcalc crude_gm.mif safe_sdm.mif 0 -if ' +
                str(crudegmmedian) + ' -gt _crudegmhigh.mif -datatype bit',
                show=False)
    run.command(
        'mrcalc _crudegmhigh.mif 0 crude_gm.mif -if _crudegmlow.mif -datatype bit',
        show=False)
    run.command(
        'mrcalc _crudegmhigh.mif safe_sdm.mif ' + str(crudegmmedian) +
        ' -subtract 0 -if - | mrthreshold - - -mask _crudegmhigh.mif -invert | mrcalc _crudegmhigh.mif - 0 -if _crudegmhighselect.mif -datatype bit',
        show=False)
    run.command(
        'mrcalc _crudegmlow.mif safe_sdm.mif ' + str(crudegmmedian) +
        ' -subtract -neg 0 -if - | mrthreshold - - -mask _crudegmlow.mif -invert | mrcalc _crudegmlow.mif - 0 -if _crudegmlowselect.mif -datatype bit',
        show=False)
    run.command(
        'mrcalc _crudegmhighselect.mif 1 _crudegmlowselect.mif -if refined_gm.mif -datatype bit',
        show=False)
    statrefgmcount = image.statistics('refined_gm.mif',
                                      mask='refined_gm.mif').count
    app.console('  [ GM: ' + str(statcrudegmcount) + ' -> ' +
                str(statrefgmcount) + ' ]')

    # Refine CSF: recover lost CSF from crude WM SDM outliers, separate safer CSF from partial volumed voxels.
    app.console('* Refining CSF...')
    crudecsfmin = image.statistics('safe_sdm.mif', mask='crude_csf.mif').min
    run.command('mrcalc _crudewmoutliers.mif safe_sdm.mif 0 -if ' +
                str(crudecsfmin) +
                ' -gt 1 crude_csf.mif -if _crudecsfextra.mif -datatype bit',
                show=False)
    run.command(
        'mrcalc _crudecsfextra.mif safe_sdm.mif ' + str(crudecsfmin) +
        ' -subtract 0 -if - | mrthreshold - - -mask _crudecsfextra.mif | mrcalc _crudecsfextra.mif - 0 -if refined_csf.mif -datatype bit',
        show=False)
    statrefcsfcount = image.statistics('refined_csf.mif',
                                       mask='refined_csf.mif').count
    app.console('  [ CSF: ' + str(statcrudecsfcount) + ' -> ' +
                str(statrefcsfcount) + ' ]')

    # FINAL VOXEL SELECTION AND RESPONSE FUNCTION ESTIMATION
    app.console('-------')
    app.console('Final voxel selection and response function estimation:')

    # Get final voxels for CSF response function estimation from refined CSF.
    app.console('* CSF:')
    app.console(' * Selecting final voxels (' + str(app.ARGS.csf) +
                '% of refined CSF)...')
    # app.ARGS.csf is a percentage of the refined CSF voxel count
    voxcsfcount = int(round(statrefcsfcount * app.ARGS.csf / 100.0))
    run.command(
        'mrcalc refined_csf.mif safe_sdm.mif 0 -if - | mrthreshold - - -top ' +
        str(voxcsfcount) +
        ' -ignorezero | mrcalc refined_csf.mif - 0 -if - -datatype bit | mrconvert - voxels_csf.mif -axes 0,1,2',
        show=False)
    statvoxcsfcount = image.statistics('voxels_csf.mif',
                                       mask='voxels_csf.mif').count
    app.console('   [ CSF: ' + str(statrefcsfcount) + ' -> ' +
                str(statvoxcsfcount) + ' ]')
    # Estimate CSF response function
    app.console(' * Estimating response function...')
    run.command(
        'amp2response dwi.mif voxels_csf.mif safe_vecs.mif response_csf.txt' +
        bvalues_option + ' -isotropic',
        show=False)

    # Get final voxels for GM response function estimation from refined GM.
    app.console('* GM:')
    app.console(' * Selecting final voxels (' + str(app.ARGS.gm) +
                '% of refined GM)...')
    voxgmcount = int(round(statrefgmcount * app.ARGS.gm / 100.0))
    refgmmedian = image.statistics('safe_sdm.mif',
                                   mask='refined_gm.mif').median
    # Rank refined GM voxels by |SDM - median| (offset by 1 so that in-mask
    #   values are non-zero for -ignorezero) and keep the smallest deviations
    run.command(
        'mrcalc refined_gm.mif safe_sdm.mif ' + str(refgmmedian) +
        ' -subtract -abs 1 -add 0 -if - | mrthreshold - - -bottom ' +
        str(voxgmcount) +
        ' -ignorezero | mrcalc refined_gm.mif - 0 -if - -datatype bit | mrconvert - voxels_gm.mif -axes 0,1,2',
        show=False)
    statvoxgmcount = image.statistics('voxels_gm.mif',
                                      mask='voxels_gm.mif').count
    app.console('   [ GM: ' + str(statrefgmcount) + ' -> ' +
                str(statvoxgmcount) + ' ]')
    # Estimate GM response function
    app.console(' * Estimating response function...')
    run.command(
        'amp2response dwi.mif voxels_gm.mif safe_vecs.mif response_gm.txt' +
        bvalues_option + ' -isotropic',
        show=False)

    # Get final voxels for single-fibre WM response function estimation from refined WM.
    app.console('* Single-fibre WM:')
    app.console(' * Selecting final voxels' +
                ('' if app.ARGS.wm_algo == 'tax' else
                 (' (' + str(app.ARGS.sfwm) + '% of refined WM)')) + '...')
    voxsfwmcount = int(round(statrefwmcount * app.ARGS.sfwm / 100.0))

    if app.ARGS.wm_algo:
        # Delegate single-fibre voxel selection to another dwi2response algorithm,
        #   run within the refined WM mask
        recursive_cleanup_option = ''
        if not app.DO_CLEANUP:
            recursive_cleanup_option = ' -nocleanup'
        app.console('   Selecting WM single-fibre voxels using \'' +
                    app.ARGS.wm_algo + '\' algorithm')
        if app.ARGS.wm_algo == 'tax' and app.ARGS.sfwm != 0.5:
            app.warn(
                'Single-fibre WM response function selection algorithm "tax" will not honour requested WM voxel percentage'
            )
        run.command(
            'dwi2response ' + app.ARGS.wm_algo +
            ' dwi.mif _respsfwmss.txt -mask refined_wm.mif -voxels voxels_sfwm.mif'
            + ('' if app.ARGS.wm_algo == 'tax' else
               (' -number ' + str(voxsfwmcount))) + ' -scratch ' +
            path.quote(app.SCRATCH_DIR) + recursive_cleanup_option,
            show=False)
    else:
        app.console(
            '   Selecting WM single-fibre voxels using built-in (Dhollander et al., 2019) algorithm'
        )
        run.command('mrmath dwi.mif mean mean_sig.mif -axis 3', show=False)
        refwmcoef = image.statistics('mean_sig.mif',
                                     mask='refined_wm.mif').median * math.sqrt(
                                         4.0 * math.pi)
        # One flag per shell: isotropic if its requested lmax is 0, or (without
        #   -lmax) if its b-value is below the b=0 threshold
        if sfwm_lmax:
            isiso = [lm == 0 for lm in sfwm_lmax]
        else:
            isiso = [bv < bzero_threshold for bv in bvalues]
        # Write an initial WM response with one row per shell
        with open('ewmrf.txt', 'w') as ewr:
            for iis in isiso:
                if iis:
                    ewr.write("%s 0 0 0\n" % refwmcoef)
                else:
                    ewr.write("%s -%s %s -%s\n" %
                              (refwmcoef, refwmcoef, refwmcoef, refwmcoef))
        # First pass at lmax 2 over the refined WM, keeping twice the target count
        run.command(
            'dwi2fod msmt_csd dwi.mif ewmrf.txt abs_ewm2.mif response_csf.txt abs_csf2.mif -mask refined_wm.mif -lmax 2,0'
            + bvalues_option,
            show=False)
        run.command(
            'mrconvert abs_ewm2.mif - -coord 3 0 | mrcalc - abs_csf2.mif -add abs_sum2.mif',
            show=False)
        run.command(
            'sh2peaks abs_ewm2.mif - -num 1 -mask refined_wm.mif | peaks2amp - - | mrcalc - abs_sum2.mif -divide - | mrconvert - metric_sfwm2.mif -coord 3 0 -axes 0,1,2',
            show=False)
        run.command(
            'mrcalc refined_wm.mif metric_sfwm2.mif 0 -if - | mrthreshold - - -top '
            + str(voxsfwmcount * 2) +
            ' -ignorezero | mrcalc refined_wm.mif - 0 -if - -datatype bit | mrconvert - refined_sfwm.mif -axes 0,1,2',
            show=False)
        # Second pass at lmax 6 over that reduced set, keeping the target count
        run.command(
            'dwi2fod msmt_csd dwi.mif ewmrf.txt abs_ewm6.mif response_csf.txt abs_csf6.mif -mask refined_sfwm.mif -lmax 6,0'
            + bvalues_option,
            show=False)
        run.command(
            'mrconvert abs_ewm6.mif - -coord 3 0 | mrcalc - abs_csf6.mif -add abs_sum6.mif',
            show=False)
        run.command(
            'sh2peaks abs_ewm6.mif - -num 1 -mask refined_sfwm.mif | peaks2amp - - | mrcalc - abs_sum6.mif -divide - | mrconvert - metric_sfwm6.mif -coord 3 0 -axes 0,1,2',
            show=False)
        run.command(
            'mrcalc refined_sfwm.mif metric_sfwm6.mif 0 -if - | mrthreshold - - -top '
            + str(voxsfwmcount) +
            ' -ignorezero | mrcalc refined_sfwm.mif - 0 -if - -datatype bit | mrconvert - voxels_sfwm.mif -axes 0,1,2',
            show=False)

    statvoxsfwmcount = image.statistics('voxels_sfwm.mif',
                                        mask='voxels_sfwm.mif').count
    app.console('   [ WM: ' + str(statrefwmcount) + ' -> ' +
                str(statvoxsfwmcount) + ' (single-fibre) ]')
    # Estimate SF WM response function
    app.console(' * Estimating response function...')
    run.command(
        'amp2response dwi.mif voxels_sfwm.mif safe_vecs.mif response_sfwm.txt'
        + bvalues_option + sfwm_lmax_option,
        show=False)

    # OUTPUT AND SUMMARY
    app.console('-------')
    app.console('Generating outputs...')

    # Generate 4D binary images with voxel selections at major stages in algorithm (RGB: WM=blue, GM=green, CSF=red).
    run.command(
        'mrcat crude_csf.mif crude_gm.mif crude_wm.mif check_crude.mif -axis 3',
        show=False)
    run.command(
        'mrcat refined_csf.mif refined_gm.mif refined_wm.mif check_refined.mif -axis 3',
        show=False)
    run.command(
        'mrcat voxels_csf.mif voxels_gm.mif voxels_sfwm.mif check_voxels.mif -axis 3',
        show=False)

    # Copy results to output files
    run.function(shutil.copyfile,
                 'response_sfwm.txt',
                 path.from_user(app.ARGS.out_sfwm, False),
                 show=False)
    run.function(shutil.copyfile,
                 'response_gm.txt',
                 path.from_user(app.ARGS.out_gm, False),
                 show=False)
    run.function(shutil.copyfile,
                 'response_csf.txt',
                 path.from_user(app.ARGS.out_csf, False),
                 show=False)
    if app.ARGS.voxels:
        run.command('mrconvert check_voxels.mif ' +
                    path.from_user(app.ARGS.voxels),
                    mrconvert_keyval=path.from_user(app.ARGS.input, False),
                    force=app.FORCE_OVERWRITE,
                    show=False)
    app.console('-------')
Exemple #4
0
def execute(): #pylint: disable=unused-variable
  """Estimate WM, GM and CSF response functions using a 5TT segmentation.

  Works on 'dwi.mif', 'mask.mif' and '5tt.mif' in the current working
  directory: the 5TT image is checked and regridded to the FA image, tissue
  masks are derived from partial volume fraction and FA thresholds, the WM
  mask is reduced to single-fibre voxels by a nested dwi2response call, and a
  response is estimated per tissue.

  The responses are copied to app.ARGS.out_wm, app.ARGS.out_gm and
  app.ARGS.out_csf; if app.ARGS.voxels is set, the 4D voxel selection image is
  written there too.

  Raises MRtrixError if the supplied lmax values are inconsistent with the
  detected shells, odd or negative, or if any tissue mask ends up empty.
  """
  # Ideally want to use the oversampling-based regridding of the 5TT image from the SIFT model, not mrtransform
  # May need to commit 5ttregrid...

  # Sanity-check the input 5TT image; a failing check is reported but not fatal
  try:
    check_output = run.command('5ttcheck 5tt.mif').stderr
  except run.MRtrixCmdError as except_5ttcheck:
    check_output = except_5ttcheck.stderr
  if any(keyword in check_output for keyword in ('WARNING', 'ERROR')):
    app.warn('Command 5ttcheck indicates problems with provided input 5TT image \'' + app.ARGS.in_5tt + '\':')
    for report_line in check_output.splitlines():
      app.warn(report_line)
    app.warn('These may or may not interfere with the dwi2response msmt_5tt script')

  # Unique b-values present in the data, rounded to integers
  shell_text = image.mrinfo('dwi.mif', 'shell_bvalues')
  shells = [ int(round(float(entry))) for entry in shell_text.split() ]
  if len(shells) < 3:
    app.warn('Less than three b-values; response functions will not be applicable in resolving three tissues using MSMT-CSD algorithm')

  # Optional per-shell lmax for the WM response
  wm_lmax = [ ]
  if app.ARGS.lmax:
    wm_lmax = [ int(entry.strip()) for entry in app.ARGS.lmax.split(',') ]
    if len(wm_lmax) != len(shells):
      raise MRtrixError('Number of manually-defined lmax\'s (' + str(len(wm_lmax)) + ') does not match number of b-values (' + str(len(shells)) + ')')
    for lmax_value in wm_lmax:
      if lmax_value % 2:
        raise MRtrixError('Values for lmax must be even')
      if lmax_value < 0:
        raise MRtrixError('Values for lmax must be non-negative')

  run.command('dwi2tensor dwi.mif - -mask mask.mif | tensor2metric - -fa fa.mif -vector vector.mif')
  # Fall back to the principal eigenvectors when no directions image exists yet
  if not os.path.exists('dirs.mif'):
    run.function(shutil.copy, 'vector.mif', 'dirs.mif')
  run.command('mrtransform 5tt.mif 5tt_regrid.mif -template fa.mif -interp linear')

  # Basic tissue masks: WM uses the partial volume threshold only, whereas GM
  #   and CSF additionally require FA below the FA threshold
  pvf_text = str(app.ARGS.pvf)
  fa_text = str(app.ARGS.fa)
  run.command('mrconvert 5tt_regrid.mif - -coord 3 2 -axes 0,1,2 | mrcalc - ' + pvf_text + ' -gt mask.mif -mult wm_mask.mif')
  for volume_index, tissue in (('0', 'gm'), ('3', 'csf')):
    run.command('mrconvert 5tt_regrid.mif - -coord 3 ' + volume_index + ' -axes 0,1,2 | mrcalc - ' + pvf_text + ' -gt fa.mif ' + fa_text + ' -lt -mult mask.mif -mult ' + tissue + '_mask.mif')

  # Restrict the WM mask to single-fibre voxels via a nested dwi2response call
  recursive_cleanup_option = '' if app.DO_CLEANUP else ' -nocleanup'
  if app.ARGS.sfwm_fa_threshold:
    app.console('Selecting WM single-fibre voxels using \'fa\' algorithm with a hard FA threshold of ' + str(app.ARGS.sfwm_fa_threshold))
    sf_command = 'dwi2response fa dwi.mif wm_ss_response.txt -mask wm_mask.mif -threshold ' + str(app.ARGS.sfwm_fa_threshold) + ' -voxels wm_sf_mask.mif'
  else:
    app.console('Selecting WM single-fibre voxels using \'' + app.ARGS.wm_algo + '\' algorithm')
    sf_command = 'dwi2response ' + app.ARGS.wm_algo + ' dwi.mif wm_ss_response.txt -mask wm_mask.mif -voxels wm_sf_mask.mif'
  run.command(sf_command + ' -scratch ' + path.quote(app.SCRATCH_DIR) + recursive_cleanup_option)

  # Refuse to continue if any tissue has no voxels to estimate a response from
  empty_masks = [ ]
  for label, mask_path in (('WM', 'wm_sf_mask.mif'), ('GM', 'gm_mask.mif'), ('CSF', 'csf_mask.mif')):
    if not image.statistics(mask_path, mask=mask_path).count:
      empty_masks.append(label)
  if empty_masks:
    plural = 's' if len(empty_masks) > 1 else ''
    raise MRtrixError(','.join(empty_masks) + ' tissue mask' + plural + ' empty; cannot estimate response function' + plural)

  # One multi-shell response per tissue; GM and CSF are forced to be isotropic
  bvalues_option = ' -shells ' + ','.join(str(bvalue) for bvalue in shells)
  sfwm_lmax_option = (' -lmax ' + ','.join(str(value) for value in wm_lmax)) if wm_lmax else ''
  response_jobs = (('wm_sf_mask.mif', 'wm.txt', sfwm_lmax_option),
                   ('gm_mask.mif', 'gm.txt', ' -isotropic'),
                   ('csf_mask.mif', 'csf.txt', ' -isotropic'))
  for mask_path, response_path, extra_option in response_jobs:
    run.command('amp2response dwi.mif ' + mask_path + ' dirs.mif ' + response_path + bvalues_option + extra_option)
  for response_path, destination in (('wm.txt', app.ARGS.out_wm), ('gm.txt', app.ARGS.out_gm), ('csf.txt', app.ARGS.out_csf)):
    run.function(shutil.copyfile, response_path, path.from_user(destination, False))

  # Generate output 4D binary image with voxel selections; RGB as in MSMT-CSD paper
  run.command('mrcat csf_mask.mif gm_mask.mif wm_sf_mask.mif voxels.mif -axis 3')
  if app.ARGS.voxels:
    run.command('mrconvert voxels.mif ' + path.from_user(app.ARGS.voxels), mrconvert_keyval=path.from_user(app.ARGS.input, False), force=app.FORCE_OVERWRITE)