Example #1
def s3_4_edi_consensus(params, edges, inputs=[]):
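    """Combine bidirectional probtrackx tract maps into two-way consensus edges.

    For each edge (a, b), verifies that both the a->b and b->a tract files exist
    and have a positive maximum intensity (checked with fslstats -R). If so, each
    map is thresholded at 5% (fslmaths -thrP 5), binarized, and the two are summed
    and re-binarized into a consensus image. Edges with empty maps are logged to
    zerosl.txt instead.
    """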
    import time
    from subscripts.utilities import run, smart_remove, smart_mkdir, write, is_float, record_apptime
    from os.path import join, exists, split
    sdir = params['sdir']
    stdout = params['stdout']
    container = params['container']
    start_time = time.time()
    pbtk_dir = join(sdir, "EDI", "PBTKresults")
    consensus_dir = join(pbtk_dir, "twoway_consensus_edges")
    smart_mkdir(consensus_dir)  # ensure the output dir exists before writing consensus images

    for edge in edges:
        a, b = edge
        a_to_b = "{}_to_{}".format(a, b)
        a_to_b_file = join(pbtk_dir, "{}_s2fato{}_s2fa.nii.gz".format(a, b))
        b_to_a_file = join(pbtk_dir, "{}_s2fato{}_s2fa.nii.gz".format(b, a))
        if not exists(a_to_b_file):
            write(stdout, "Error: cannot find {}".format(a_to_b_file))
            return
        if not exists(b_to_a_file):
            write(stdout, "Error: cannot find {}".format(b_to_a_file))
            return
        consensus = join(consensus_dir, a_to_b + '.nii.gz')
        amax = run("fslstats {} -R | cut -f 2 -d \" \" ".format(a_to_b_file),
                   params).strip()
        if not is_float(amax):
            write(
                stdout,
                "Error: fslstats on {} returns invalid value {}".format(
                    a_to_b_file, amax))
            return
        amax = int(float(amax))
        bmax = run("fslstats {} -R | cut -f 2 -d \" \" ".format(b_to_a_file),
                   params).strip()
        if not is_float(bmax):
            write(
                stdout,
                "Error: fslstats on {} returns invalid value {}".format(
                    b_to_a_file, bmax))
            return
        bmax = int(float(bmax))
        write(stdout, "amax = {}, bmax = {}".format(amax, bmax))
        if amax > 0 and bmax > 0:
            tmp1 = join(pbtk_dir, "{}_to_{}_tmp1.nii.gz".format(a, b))
            tmp2 = join(pbtk_dir, "{}_to_{}_tmp2.nii.gz".format(b, a))
            run("fslmaths {} -thrP 5 -bin {}".format(a_to_b_file, tmp1),
                params)
            run("fslmaths {} -thrP 5 -bin {}".format(b_to_a_file, tmp2),
                params)
            run(
                "fslmaths {} -add {} -thr 1 -bin {}".format(
                    tmp1, tmp2, consensus), params)
            smart_remove(tmp1)
            smart_remove(tmp2)
        else:
            with open(join(pbtk_dir, "zerosl.txt"), 'a') as log:
                log.write("For edge {}:\n".format(a_to_b))
                log.write("{} is thresholded to {}\n".format(a, amax))
                log.write("{} is thresholded to {}\n".format(b, bmax))
    record_apptime(params, start_time, 3)
Example #2
def s_2_debug(params, inputs=[]):
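    """Debug stub: sleep for 15 seconds, then record app time and mark the stage finished."""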
    import time
    from subscripts.utilities import record_apptime, record_finish, write
    start_time = time.time()
    time.sleep(15)
    record_apptime(params, start_time, 2)
    record_finish(params)
Example #3
def s2a_bedpostx(params, inputs=[]):
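    """Fit the FSL bedpostx crossing-fibre model on the eddy-corrected data.

    Stages data, brain mask, bvals, and bvecs into bedpostx_b1000/, runs bedpostx
    (or bedpostx_gpu through a generated shell script when use_gpu is set, with a
    log check that all four jobs completed), then builds dyadic vectors from the
    merged theta/phi samples and validates the key outputs.
    """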
    import time
    from subscripts.utilities import run,smart_mkdir,smart_remove,write,record_start,record_apptime,record_finish,update_permissions,validate
    from os.path import exists,join,split
    from shutil import copyfile,rmtree
    sdir = params['sdir']
    stdout = params['stdout']
    container = params['container']
    cores_per_task = params['cores_per_task']
    use_gpu = params['use_gpu']
    group = params['group']
    record_start(params)
    start_time = time.time()
    bedpostx = join(sdir,"bedpostx_b1000")
    bedpostxResults = join(sdir,"bedpostx_b1000.bedpostX")
    th1 = join(bedpostxResults, "merged_th1samples")
    ph1 = join(bedpostxResults, "merged_ph1samples")
    th2 = join(bedpostxResults, "merged_th2samples")
    ph2 = join(bedpostxResults, "merged_ph2samples")
    dyads1 = join(bedpostxResults, "dyads1")
    dyads2 = join(bedpostxResults, "dyads2")
    brain_mask = join(bedpostxResults, "nodif_brain_mask")
    if exists(bedpostxResults):
        rmtree(bedpostxResults)
    smart_mkdir(bedpostx)
    smart_mkdir(bedpostxResults)
    copyfile(join(sdir,"data_eddy.nii.gz"),join(bedpostx,"data.nii.gz"))
    copyfile(join(sdir,"data_bet_mask.nii.gz"),join(bedpostx,"nodif_brain_mask.nii.gz"))
    copyfile(join(sdir,"bvals"),join(bedpostx,"bvals"))
    copyfile(join(sdir,"bvecs"),join(bedpostx,"bvecs"))

    if use_gpu:
        write(stdout, "Running Bedpostx with GPU")
        bedpostx_sh = join(sdir, "bedpostx.sh")
        smart_remove(bedpostx_sh)
        odir = split(sdir)[0]
        write(bedpostx_sh, "export CUDA_LIB_DIR=$CUDA_8_LIB_DIR\n" +
                           "export LD_LIBRARY_PATH=$CUDA_LIB_DIR:$LD_LIBRARY_PATH")
        if container:
            write(bedpostx_sh, "bedpostx_gpu {} -NJOBS 4".format(bedpostx.replace(odir, "/share")))
        else:
            write(bedpostx_sh, "bedpostx_gpu {} -NJOBS 4".format(bedpostx))
        run("sh " + bedpostx_sh, params)
        # hacky validation step
        with open(stdout) as f:
            log_content = f.read()
            for i in range(1, 5):
                assert("{:d} parts processed out of 4".format(i) in log_content)
    else:
        write(stdout, "Running Bedpostx without GPU")
        run("bedpostx {}".format(bedpostx), params)
    run("make_dyadic_vectors {} {} {} {}".format(th1,ph1,brain_mask,dyads1), params)
    run("make_dyadic_vectors {} {} {} {}".format(th2,ph2,brain_mask,dyads2), params)
    validate(th1, params)
    validate(ph1, params)
    validate(dyads1, params)
    update_permissions(params)
    record_apptime(params, start_time, 1)
    record_finish(params)
Example #4
def s2b_1_recon_all(params, inputs=[]):
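    """Run FreeSurfer recon-all on the subject's T1 image.

    Converts T1.nii.gz to mri/orig/001.mgz, points SUBJECTS_DIR at the output
    directory when not running in a container, and invokes recon-all with GPU,
    multi-core, or single-core flags depending on params.
    """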
    import time
    from copy import deepcopy
    from subscripts.utilities import run, smart_mkdir, smart_remove, write, record_apptime, record_start, copy_dir
    from os import environ
    from os.path import exists, join, split, basename
    # work_sdir = params['work_sdir']
    # if work_sdir:
    #     old_sdir = params['sdir']
    #     copy_dir(old_sdir, work_sdir)
    #     params = deepcopy(params) # don't modify original param dict
    #     params['sdir'] = work_sdir
    sdir = params['sdir']
    stdout = params['stdout']
    container = params['container']
    cores_per_task = params['cores_per_task']
    use_gpu = params['use_gpu']
    group = params['group']
    subject = split(sdir)[1]
    record_start(params)
    start_time = time.time()
    T1 = join(sdir, "T1.nii.gz")
    if not exists(T1):
        raise Exception('Missing T1 file at {}'.format(T1))
    mri_out = join(sdir, "mri", "orig", "001.mgz")
    smart_mkdir(join(sdir, "mri"))
    smart_mkdir(join(sdir, "mri", "orig"))
    run("mri_convert {} {}".format(T1, mri_out), params)

    if not container:
        environ['SUBJECTS_DIR'] = split(sdir)[0]
    else:
        pass  # SUBJECTS_DIR already set to /share in recipe at [REPO]/container/Singularity

    if use_gpu:
        write(
            stdout,
            "Running Freesurfer with GPU and {} cores".format(cores_per_task))
        freesurfer_sh = join(sdir, "freesurfer.sh")
        smart_remove(freesurfer_sh)
        write(
            freesurfer_sh, "export CUDA_LIB_DIR=$CUDA_5_LIB_DIR\n" +
            "export LD_LIBRARY_PATH=$CUDA_LIB_DIR:$LD_LIBRARY_PATH\n" +
            "recon-all -s {} -all -notal-check -no-isrunning -use-gpu -parallel -openmp {}"
            .format(subject, cores_per_task))
        run("sh " + freesurfer_sh, params)
    elif cores_per_task > 1:
        write(stdout,
              "Running Freesurfer with {} cores".format(cores_per_task))
        run(
            "recon-all -s {} -all -notal-check -no-isrunning -parallel -openmp {}"
            .format(subject, cores_per_task), params)
    else:
        write(stdout, "Running Freesurfer with a single core")
        run("recon-all -s {} -all -notal-check -no-isrunning".format(subject),
            params)
    record_apptime(params, start_time, 1)
Example #5
def s1_4_dti_fit(params, inputs=[]):
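    """Merge the registered timeslices and fit the diffusion tensor.

    Reassembles data_eddy.nii.gz with fslmerge, extracts a brain mask with bet,
    runs dtifit, and derives MD = (L1+L2+L3)/3, RD = (L2+L3)/2, AD = L1, and FA
    maps. Temporary timeslice and reference files are removed afterwards.
    """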
    import time
    from subscripts.utilities import run, smart_remove, write, record_apptime, record_finish, update_permissions
    from os.path import join, exists
    from shutil import copyfile
    from glob import glob
    sdir = params['sdir']
    stdout = params['stdout']
    container = params['container']
    cores_per_task = params['cores_per_task']
    start_time = time.time()
    output_prefix = join(sdir, "data_eddy")
    output_data = join(sdir, "data_eddy.nii.gz")
    timeslices = glob("{}_tmp????.nii.gz".format(output_prefix))
    timeslices.sort()
    bet = join(sdir, "data_bet.nii.gz")
    bvecs = join(sdir, "bvecs")
    bvals = join(sdir, "bvals")
    bet_mask = join(sdir, "data_bet_mask.nii.gz")
    dti_params = join(sdir, "DTIparams")
    dti_L1 = dti_params + "_L1.nii.gz"
    dti_L2 = dti_params + "_L2.nii.gz"
    dti_L3 = dti_params + "_L3.nii.gz"
    dti_MD = dti_params + "_MD.nii.gz"
    dti_RD = dti_params + "_RD.nii.gz"
    dti_AD = dti_params + "_AD.nii.gz"
    dti_FA = dti_params + "_FA.nii.gz"
    FA = join(sdir, "FA.nii.gz")
    run("fslmerge -t {} {}".format(output_data, " ".join(timeslices)), params)
    run("bet {} {} -m -f 0.3".format(output_data, bet), params)

    if exists(bet_mask):
        run(
            "dtifit --verbose -k {} -o {} -m {} -r {} -b {}".format(
                output_data, dti_params, bet_mask, bvecs, bvals), params)
        run(
            "fslmaths {} -add {} -add {} -div 3 {}".format(
                dti_L1, dti_L2, dti_L3, dti_MD), params)
        run("fslmaths {} -add {} -div 2 {}".format(dti_L2, dti_L3, dti_RD),
            params)
        copyfile(dti_L1, dti_AD)
        copyfile(dti_FA, FA)
    else:
        write(stdout, "Warning: failed to generate masked outputs")
        raise Exception(
            f"Failed BET step. Please check {stdout} for more info.")

    for i in glob("{}_tmp????.*".format(output_prefix)):
        smart_remove(i)
    for j in glob("{}_ref*".format(output_prefix)):
        smart_remove(j)
    update_permissions(params)
    record_apptime(params, start_time, 4)
    record_finish(params)
Example #6
def s_1_debug(params, inputs=[]):
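    """Debug stub: echo a container test message if applicable, then sleep for 10 seconds."""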
    import time
    from subscripts.utilities import run, record_start, record_apptime, write
    record_start(params)
    start_time = time.time()
    sdir = params['sdir']
    container = params['container']
    if container:
        run(
            "echo 'Testing Singularity on compute node\nShare dir is {}'".
            format(sdir), params)
    time.sleep(10)
    record_apptime(params, start_time, 1)
Example #7
def s4_2_render_target(params, input_file, output_file, inputs=[]):
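    """Render input_file to output_file by invoking the subject's run_vtk.py with vtkpython."""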
    import time
    from subscripts.utilities import record_apptime, run, write
    from os.path import splitext, join, exists
    start_time = time.time()
    sdir = params['sdir']
    run_vtk = join(sdir, 'run_vtk.py')
    stdout = params['stdout']
    histogram_bin_count = params['histogram_bin_count']
    if not exists(input_file):
        write(stdout, "Cannot find input file {}".format(input_file))
        return
    run(
        '/opt/vtk/bin/vtkpython {} {} {} {}'.format(run_vtk, input_file,
                                                    output_file,
                                                    histogram_bin_count),
        params)
    record_apptime(params, start_time, 1)
Example #8
def s1_2_split_timeslices(params, inputs=[]):
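    """Split hardi.nii.gz into single-volume timeslices for parallel registration.

    Removes leftover timeslice and reference files, extracts the first volume as
    a registration reference with fslroi, then splits the 4D input into
    data_eddy_tmpNNNN files with fslsplit.
    """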
    import time
    from subscripts.utilities import run,record_apptime,smart_remove,smart_copy
    from os.path import join
    from glob import glob
    sdir = params['sdir']
    stdout = params['stdout']
    container = params['container']
    start_time = time.time()
    output_prefix = join(sdir,"data_eddy")
    timeslices = glob("{}_tmp????.*".format(output_prefix))
    for i in timeslices:
        smart_remove(i)
    for j in glob("{}_ref*".format(output_prefix)):
        smart_remove(j)
    input_data = join(sdir, "hardi.nii.gz")
    run("fslroi {} {}_ref 0 1".format(input_data, output_prefix), params)
    run("fslsplit {} {}_tmp".format(input_data, output_prefix), params)
    record_apptime(params, start_time, 2)
Example #9
def s1_3_timeslice_process(params, worker_id, num_workers, inputs=[]):
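    """Register this worker's share of timeslices to the reference volume.

    Worker i of n handles timeslices i, i+n, i+2n, ..., running flirt on each,
    so all slices are processed in parallel without overlap.
    """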
    import time
    from subscripts.utilities import run,record_apptime
    from os.path import join,exists
    sdir = params['sdir']
    stdout = params['stdout']
    container = params['container']
    start_time = time.time()
    output_prefix = join(sdir,"data_eddy")
    timeslice = worker_id
    slice_data = join(sdir,"data_eddy_tmp{:04d}.nii.gz".format(timeslice))
    iteration = 0
    while exists(slice_data):
        # Break loop if it gets stuck
        if iteration > 99:
            break
        run("flirt -in {0} -ref {1}_ref -nosearch -interp trilinear -o {0} -paddingsize 1".format(slice_data, output_prefix), params)
        # Example: worker #3 with 10 total workers will process timeslices 3, 13, 23, 33...
        timeslice += num_workers
        slice_data = join(sdir,"data_eddy_tmp{:04d}.nii.gz".format(timeslice))
        iteration += 1
    record_apptime(params, start_time, 3)
Example #10
def s3_3_combine(params, volumes, inputs=[]):
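    """Sum the per-volume probtrackx path maps into a single total image.

    Each volume's fdt_paths map is thresholded at 5% and binarized, then added
    into fast_outdir/FAtractsumsTwoway.nii.gz.
    """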
    import time
    from subscripts.utilities import run,record_apptime,record_finish,update_permissions,write
    from os.path import join,exists
    sdir = params['sdir']
    # volumes = params['volumes']
    start_time = time.time()
    # run_vtk = join(sdir, 'run_vtk.py')
    outdir = join(sdir, 'fast_outdir')
    allvoxelscortsubcort = join(sdir,"allvoxelscortsubcort.nii.gz")
    total = join(outdir, 'FAtractsumsTwoway.nii.gz')
    run("fslmaths {} -mul 0 {}".format(allvoxelscortsubcort, total), params)
    for vol in volumes:
        vol_outdir = join(outdir, vol)
        pbtx_result = join(vol_outdir, 'fdt_paths.nii.gz')
        run("fslmaths {} -thrP 5 -bin {}".format(pbtx_result, pbtx_result), params)
        run("fslmaths {} -add {} {}".format(pbtx_result, total, total), params)

    # run('/opt/vtk/bin/vtkpython {} {} {} {}'.format(run_vtk, total, total + '.png', 256), params)

    update_permissions(params)
    record_apptime(params, start_time, 2)
    record_finish(params)
Example #11
def s2b_2_process_vols(params, inputs=[]):
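    """Generate cortical and subcortical seed volumes in FA space.

    Converts the FreeSurfer brain and aseg outputs, registers FA to T1 and
    inverts the transform, rasterizes aparc labels and the listed subcortical
    structures, warps every volume into FA space, masks the seeds, and builds
    the termination mask used by probtrackx.
    """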
    import time
    from subscripts.utilities import run, smart_mkdir, smart_remove, write, record_apptime, record_finish, update_permissions
    from subscripts.maskseeds import maskseeds, saveallvoxels
    from os.path import exists, join, split, splitext
    from os import environ
    from shutil import copy
    from glob import glob
    sdir = params['sdir']
    stdout = params['stdout']
    container = params['container']
    cores_per_task = params['cores_per_task']
    group = params['group']
    start_time = time.time()
    T1 = join(sdir, "T1.nii.gz")
    subject = split(sdir)[1]
    FA = join(sdir, "FA.nii.gz")
    aseg = join(sdir, "aseg.nii.gz")
    bs = join(sdir, "bs.nii.gz")
    FA2T1 = join(sdir, "FA2T1.mat")
    T12FA = join(sdir, "T12FA.mat")
    cort_label_dir = join(sdir, "label_cortical")
    cort_vol_dir = join(sdir, "volumes_cortical")
    cort_vol_dir_out = cort_vol_dir + "_s2fa"
    subcort_vol_dir = join(sdir, "volumes_subcortical")
    subcort_vol_dir_out = subcort_vol_dir + "_s2fa"
    terminationmask = join(sdir, "terminationmask.nii.gz")
    allvoxelscortsubcort = join(sdir, "allvoxelscortsubcort.nii.gz")
    intersection = join(sdir, "intersection.nii.gz")
    # exclusion_bsplusthalami = join(sdir,"exclusion_bsplusthalami.nii.gz")
    subcortical_index = [
        '10:lh_thalamus',
        '11:lh_caudate',
        '12:lh_putamen',
        '13:lh_pallidum',
        '17:lh_hippocampus',
        '18:lh_amygdala',
        '26:lh_acumbens',
        '49:rh_thalamus',
        '50:rh_caudate',
        '51:rh_putamen',
        '52:rh_pallidum',
        '53:rh_hippocampus',
        '54:rh_amygdala',
        '58:rh_acumbens',
    ]
    EDI = join(sdir, "EDI")
    EDI_allvols = join(EDI, "allvols")
    smart_mkdir(cort_label_dir)
    smart_mkdir(cort_vol_dir)
    smart_mkdir(subcort_vol_dir)
    smart_mkdir(cort_vol_dir_out)
    smart_mkdir(subcort_vol_dir_out)
    smart_mkdir(EDI)
    smart_mkdir(EDI_allvols)

    if not container:
        environ['SUBJECTS_DIR'] = split(sdir)[0]

    run("mri_convert {} {} ".format(join(sdir, "mri", "brain.mgz"), T1),
        params)
    run("flirt -in {} -ref {} -omat {}".format(FA, T1, FA2T1), params)
    run("convert_xfm -omat {} -inverse {}".format(T12FA, FA2T1), params)
    run(
        "mri_annotation2label --subject {} --hemi rh --annotation aparc --outdir {}"
        .format(subject, cort_label_dir), params)
    run(
        "mri_annotation2label --subject {} --hemi lh --annotation aparc --outdir {}"
        .format(subject, cort_label_dir), params)

    for label in glob(join(cort_label_dir, "*.label")):
        vol_file = join(cort_vol_dir, splitext(split(label)[1])[0] + ".nii.gz")
        run(
            "mri_label2vol --label {} --temp {} --identity --o {}".format(
                label, T1, vol_file), params)

    run("mri_convert {} {}".format(join(sdir, "mri", "aseg.mgz"), aseg),
        params)
    for line in subcortical_index:
        num = line.split(":")[0].lstrip().rstrip()
        area = line.split(":")[1].lstrip().rstrip()
        area_out = join(subcort_vol_dir, area + ".nii.gz")
        write(stdout, "Processing " + area + ".nii.gz")
        run(
            "fslmaths {} -uthr {} -thr {} -bin {}".format(
                aseg, num, num, area_out), params)

    for volume in glob(join(cort_vol_dir, "*.nii.gz")):
        out_vol = join(
            cort_vol_dir_out,
            splitext(splitext(split(volume)[1])[0])[0] + "_s2fa.nii.gz")
        write(
            stdout,
            "Processing {} -> {}".format(split(volume)[1],
                                         split(out_vol)[1]))
        run(
            "flirt -in {} -ref {} -out {}  -applyxfm -init {}".format(
                volume, FA, out_vol, T12FA), params)
        run("fslmaths {} -thr 0.2 -bin {} ".format(out_vol, out_vol), params)

    for volume in glob(join(subcort_vol_dir, "*.nii.gz")):
        out_vol = join(
            subcort_vol_dir_out,
            splitext(splitext(split(volume)[1])[0])[0] + "_s2fa.nii.gz")
        write(
            stdout,
            "Processing {} -> {}".format(split(volume)[1],
                                         split(out_vol)[1]))
        run(
            "flirt -in {} -ref {} -out {}  -applyxfm -init {}".format(
                volume, FA, out_vol, T12FA), params)
        run("fslmaths {} -thr 0.2 -bin {}".format(out_vol, out_vol), params)

    run("fslmaths {} -mul 0 {}".format(FA, bs),
        params)  # For now we fake a bs.nii.gz file
    maskseeds(sdir, join(cort_vol_dir + "_s2fa"),
              join(cort_vol_dir + "_s2fa_m"), 0.05, 1, 1, params)
    maskseeds(sdir, join(subcort_vol_dir + "_s2fa"),
              join(subcort_vol_dir + "_s2fa_m"), 0.05, 0.4, 0.4, params)
    saveallvoxels(sdir, join(cort_vol_dir + "_s2fa_m"),
                  join(subcort_vol_dir + "_s2fa_m"), allvoxelscortsubcort,
                  params)
    smart_remove(terminationmask)
    run("fslmaths {} -uthr .15 {}".format(FA, terminationmask), params)
    run("fslmaths {} -add {} {}".format(terminationmask, bs, terminationmask),
        params)
    run("fslmaths {} -bin {}".format(terminationmask, terminationmask), params)
    run(
        "fslmaths {} -mul {} {}".format(terminationmask, allvoxelscortsubcort,
                                        intersection), params)
    run(
        "fslmaths {} -sub {} {}".format(terminationmask, intersection,
                                        terminationmask), params)
    # run("fslmaths {} -add {} -add {} {}".format(bs,
    #                                             join(subcort_vol_dir + "_s2fa_m","lh_thalamus_s2fa.nii.gz"),
    #                                             join(subcort_vol_dir + "_s2fa_m","rh_thalamus_s2fa.nii.gz"),
    #                                             exclusion_bsplusthalami), params)
    for file in glob(join(sdir, "volumes_cortical_s2fa", "*.nii.gz")):
        copy(file, EDI_allvols)
    for file in glob(join(sdir, "volumes_subcortical_s2fa", "*.nii.gz")):
        copy(file, EDI_allvols)
    update_permissions(params)
    record_apptime(params, start_time, 2)
    record_finish(params)
Example #12
def s3_2_probtrackx(params, vol, volumes, inputs=[]):
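    """Run probtrackx2 from one seed volume against all other volumes.

    Builds a waypoint/target list from the remaining volumes, subtracts the
    seed from the shared termination and exclusion masks, and invokes
    probtrackx2 (or probtrackx2_gpu via a generated script) with seed-to-target
    output enabled.
    """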
    import time
    from subscripts.utilities import run,smart_remove,smart_mkdir,write,record_start,record_apptime,sub_binary_vol
    from os.path import join,exists,split
    from shutil import copyfile
    start_time = time.time()
    sdir = params['sdir']
    use_gpu = params['use_gpu']
    pbtx_sample_count = int(params['pbtx_sample_count'])
    connectome_oneway = join(sdir, "connectome_oneway.dot")
    bedpostxResults = join(sdir,"bedpostx_b1000.bedpostX")
    merged = join(bedpostxResults,"merged")
    nodif_brain_mask = join(bedpostxResults,"nodif_brain_mask.nii.gz")
    outdir = join(sdir, 'fast_outdir')
    exclusion = join(outdir, "exclusion.nii.gz")
    termination = join(outdir, "termination.nii.gz")
    EDI_allvols = join(sdir,"EDI","allvols")

    vol_file = join(EDI_allvols, vol + "_s2fa.nii.gz")
    if not exists(vol_file):
        raise Exception('Failed to find volume {}'.format(vol_file))
    vol_outdir = join(outdir, vol)
    smart_remove(vol_outdir)
    smart_mkdir(vol_outdir)
    waypoints = join(vol_outdir, 'waypoint.txt')
    for vol2 in volumes:
        if vol != vol2:
            vol2_file = join(EDI_allvols, vol2 + "_s2fa.nii.gz")
            if not exists(vol2_file):
                raise Exception('Failed to find volume {}'.format(vol2_file))
            write(waypoints, vol2_file, params)
    vol_termination = join(vol_outdir, "vol_termination.nii.gz")
    vol_exclusion = join(vol_outdir, "vol_exclusion.nii.gz")
    copyfile(termination, vol_termination)
    copyfile(exclusion, vol_exclusion)
    sub_binary_vol(vol_file, vol_termination, params)
    sub_binary_vol(vol_file, vol_exclusion, params)
    vol_formatted = "fdt_paths.nii.gz"

    pbtx_args = (" -x {} ".format(vol_file) +
                " --pd -l -c 0.2 -S 2000 --steplength=0.5 -P {}".format(pbtx_sample_count) +
                " --waycond=OR --waypoints={}".format(waypoints) +
                " --os2t --s2tastext --targetmasks={}".format(waypoints) +
                " --stop={}".format(vol_termination) +
                " --avoid={}".format(vol_exclusion) +
                " --forcedir --opd" +
                " -s {}".format(merged) +
                " -m {}".format(nodif_brain_mask) +
                " --dir={}".format(vol_outdir) +
                " --out={}".format(vol_formatted)
                )
    if use_gpu:
        probtrackx2_sh = join(vol_outdir, "probtrackx2.sh")
        smart_remove(probtrackx2_sh)
        write(probtrackx2_sh, "export CUDA_LIB_DIR=$CUDA_8_LIB_DIR\n" +
                               "export LD_LIBRARY_PATH=$CUDA_LIB_DIR:$LD_LIBRARY_PATH\n" +
                               "probtrackx2_gpu" + pbtx_args, params)
        run("sh " + probtrackx2_sh, params)
    else:
        run("probtrackx2" + pbtx_args, params)

    # vol_connectome = join(vol_outdir,"connectome.dot")
    # waytotal = join(vol_outdir, "waytotal")
    # if not exists(waytotal):
    #     write(stdout, 'Error: failed to find waytotal for volume {}'.format(vol))

    # with open(waytotal, 'r') as f:
    #     waytotal_count = f.read().strip()
    #     fdt_count = run("fslmeants -i {} -m {} | head -n 1".format(join(vol_outdir, vol_formatted), vol_file), params)
    #     if not is_float(waytotal_count):
    #         raise Exception("Failed to read waytotal_count value {}".format(waytotal_count))
    #     if not is_float(fdt_count):
    #         raise Exception("Failed to read fdt_count value {}".format(fdt_count))
    #     write(vol_connectome, "{} {} {}".format(vol, waytotal_count, fdt_count))

    record_apptime(params, start_time, 1)
Example #13
def s3_5_edi_combine(params, consensus_edges, inputs=[]):
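    """Aggregate EDI results into whole-brain summary maps.

    Sums the raw tract maps into FAtractsumsRaw.nii.gz and the two-way
    consensus edges into FAtractsumsTwoway.nii.gz, then optionally compresses
    the probtrackx output directories into .tar.gz archives.
    """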
    import time,tarfile
    from subscripts.utilities import run,smart_remove,smart_mkdir,write,record_apptime,record_finish, \
                                     update_permissions,get_edges_from_file,strip_trailing_slash
    from os.path import join,exists,basename
    from shutil import copyfile
    pbtx_edge_list = params['pbtx_edge_list']
    sdir = params['sdir']
    stdout = params['stdout']
    container = params['container']
    start_time = time.time()
    pbtk_dir = join(sdir,"EDI","PBTKresults")
    connectome_dir = join(sdir,"EDI","CNTMresults")
    compress_pbtx_results = params['compress_pbtx_results']
    consensus_dir = join(pbtk_dir,"twoway_consensus_edges")
    edi_maps = join(sdir,"EDI","EDImaps")
    edge_total = join(edi_maps,"FAtractsumsTwoway.nii.gz")
    tract_total = join(edi_maps,"FAtractsumsRaw.nii.gz")
    smart_remove(edi_maps)
    smart_mkdir(edi_maps)

    # Collect number of probtrackx tracts per voxel
    for edge in get_edges_from_file(pbtx_edge_list):
        a, b = edge
        a_to_b_formatted = "{}_s2fato{}_s2fa.nii.gz".format(a,b)
        a_to_b_file = join(pbtk_dir,a_to_b_formatted)
        if not exists(tract_total):
            copyfile(a_to_b_file, tract_total)
        else:
            run("fslmaths {0} -add {1} {1}".format(a_to_b_file, tract_total), params)

    # Collect number of parcel-to-parcel edges per voxel
    for edge in consensus_edges:
        a, b = edge
        consensus = join(consensus_dir, "{}_to_{}.nii.gz".format(a,b))
        if not exists(consensus):
            write(stdout,"{} has been thresholded. See {} for details".format(edge, join(pbtk_dir, "zerosl.txt")))
            continue
        if not exists(edge_total):
            copyfile(consensus, edge_total)
        else:
            run("fslmaths {0} -add {1} {1}".format(consensus, edge_total), params)
    if not exists(edge_total):
        write(stdout, "Error: Failed to generate {}".format(edge_total))

    if compress_pbtx_results:
        pbtk_archive = strip_trailing_slash(pbtk_dir) + '.tar.gz'
        connectome_archive = strip_trailing_slash(connectome_dir) + '.tar.gz'
        write(stdout,"\nCompressing probtrackx output at {} and {}".format(pbtk_archive, connectome_archive))
        smart_remove(pbtk_archive)
        smart_remove(connectome_archive)
        with tarfile.open(pbtk_archive, mode='w:gz') as archive:
            archive.add(pbtk_dir, recursive=True, arcname=basename(pbtk_dir))
        with tarfile.open(connectome_archive, mode='w:gz') as archive:
            archive.add(connectome_dir, recursive=True, arcname=basename(connectome_dir))
        smart_remove(pbtk_dir)
        smart_remove(connectome_dir)

    update_permissions(params)
    record_apptime(params, start_time, 4)
    record_finish(params)
Example #14
def s3_3_combine(params, inputs=[]):
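    """Assemble one-way and two-way connectome matrices from per-edge files.

    Reads waytotal and fdt counts from each edge's .dot file in EDI/CNTMresults,
    merges reciprocal edges for the two-way variant, and saves the results both
    as text lists and as MATLAB .mat matrices indexed by the connectome idx list.
    """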
    import numpy as np
    import scipy.io
    import time
    from subscripts.utilities import record_apptime,record_finish,update_permissions,is_float,write,get_edges_from_file,smart_remove
    from os.path import join,exists
    from shutil import copyfile
    sdir = params['sdir']
    stdout = params['stdout']
    pbtx_sample_count = int(params['pbtx_sample_count'])
    pbtx_edge_list = params['pbtx_edge_list']
    connectome_idx_list = params['connectome_idx_list']
    connectome_idx_list_copy = join(sdir, 'connectome_idxs.txt')
    start_time = time.time()
    connectome_dir = join(sdir,"EDI","CNTMresults")
    oneway_list = join(sdir, "connectome_{}samples_oneway.txt".format(pbtx_sample_count))
    twoway_list = join(sdir, "connectome_{}samples_twoway.txt".format(pbtx_sample_count))
    oneway_nof = join(sdir, "connectome_{}samples_oneway_nof.mat".format(pbtx_sample_count)) # nof = number of fibers
    twoway_nof = join(sdir, "connectome_{}samples_twoway_nof.mat".format(pbtx_sample_count))
    oneway_nof_normalized = join(sdir, "connectome_{}samples_oneway_nofn.mat".format(pbtx_sample_count)) # nofn = number of fibers, normalized
    twoway_nof_normalized = join(sdir, "connectome_{}samples_twoway_nofn.mat".format(pbtx_sample_count))
    smart_remove(oneway_list)
    smart_remove(twoway_list)
    smart_remove(oneway_nof_normalized)
    smart_remove(twoway_nof_normalized)
    smart_remove(oneway_nof)
    smart_remove(twoway_nof)
    oneway_edges = {}
    twoway_edges = {}

    copyfile(connectome_idx_list, connectome_idx_list_copy) # give each subject a copy for reference

    vol_idxs = {}
    with open(connectome_idx_list) as f:
        lines = [x.strip() for x in f.readlines() if x]
        max_idx = -1
        for line in lines:
            vol, idx = line.split(',', 1)
            idx = int(idx)
            vol_idxs[vol] = idx
            if idx > max_idx:
                max_idx = idx
        oneway_nof_normalized_matrix = np.zeros((max_idx+1, max_idx+1))
        oneway_nof_matrix = np.zeros((max_idx+1, max_idx+1))
        twoway_nof_normalized_matrix = np.zeros((max_idx+1, max_idx+1))
        twoway_nof_matrix = np.zeros((max_idx+1, max_idx+1))

    for edge in get_edges_from_file(pbtx_edge_list):
        a, b = edge
        edge_file = join(connectome_dir, "{}_to_{}.dot".format(a, b))
        with open(edge_file) as f:
            chunks = [x.strip() for x in f.read().strip().split(' ') if x]
            a_to_b = (chunks[0], chunks[1])
            b_to_a = (chunks[1], chunks[0])
            waytotal_count = float(chunks[2])
            fdt_count = float(chunks[3])
            if b_to_a in twoway_edges:
                twoway_edges[b_to_a][0] += waytotal_count
                twoway_edges[b_to_a][1] += fdt_count
            else:
                twoway_edges[a_to_b] = [waytotal_count, fdt_count]
            oneway_edges[a_to_b] = [waytotal_count, fdt_count]

    for a_to_b in oneway_edges:
        a = a_to_b[0]
        b = a_to_b[1]
        for vol in a_to_b:
            if vol not in vol_idxs:
                write(stdout, 'Error: could not find {} in connectome idxs'.format(vol))
                break
        else:
            write(oneway_list, "{} {} {} {}".format(a, b, oneway_edges[a_to_b][0], oneway_edges[a_to_b][1]))
            oneway_nof_matrix[vol_idxs[a]][vol_idxs[b]] = oneway_edges[a_to_b][0]
            oneway_nof_normalized_matrix[vol_idxs[a]][vol_idxs[b]] = oneway_edges[a_to_b][1]

    for a_to_b in twoway_edges:
        a = a_to_b[0]
        b = a_to_b[1]
        for vol in a_to_b:
            if vol not in vol_idxs:
                write(stdout, 'Error: could not find {} in connectome idxs'.format(vol))
                break
        else:
            write(twoway_list, "{} {} {} {}".format(a, b, twoway_edges[a_to_b][0], twoway_edges[a_to_b][1]))
            twoway_nof_matrix[vol_idxs[a]][vol_idxs[b]] = twoway_edges[a_to_b][0]
            twoway_nof_normalized_matrix[vol_idxs[a]][vol_idxs[b]] = twoway_edges[a_to_b][1]
    scipy.io.savemat(oneway_nof, {'data': oneway_nof_matrix})
    scipy.io.savemat(oneway_nof_normalized, {'data': oneway_nof_normalized_matrix})
    scipy.io.savemat(twoway_nof, {'data': twoway_nof_matrix})
    scipy.io.savemat(twoway_nof_normalized, {'data': twoway_nof_normalized_matrix})

    update_permissions(params)
    record_apptime(params, start_time, 2)
    record_finish(params)
Example #15
def s3_2_probtrackx(params, edges, inputs=[]):
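    """Run probtrackx2 over a batch of edges, throttled by a node memory budget.

    Before tracking, estimates this task's memory footprint from its input file
    sizes and cooperatively sleeps (coordinated through a flock-protected JSON
    record shared by all tasks on the node) until the total estimated usage fits
    under pbtx_max_memory. Each edge then gets its own exclusion and termination
    masks and a probtrackx2 (or probtrackx2_gpu) run, and the resulting waytotal
    and fdt counts are written to a per-edge connectome .dot file.
    """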
    import time,platform,json,random,tempfile,os,fcntl,errno
    from subscripts.utilities import run,smart_remove,smart_mkdir,write,is_float,is_integer,record_start,record_apptime
    from os.path import join,exists,split,dirname
    from shutil import copyfile
    sdir = params['sdir']
    odir = split(sdir)[0]
    derivatives_dir = params['derivatives_dir']
    derivatives_dir_tmp = join(derivatives_dir, "tmp")
    sdir_tmp = join(sdir, "tmp")
    stdout = params['stdout']
    container = params['container']
    use_gpu = params['use_gpu']
    pbtx_sample_count = int(params['pbtx_sample_count'])
    subject_random_seed = params['subject_random_seed']
    if use_gpu:
        pbtx_max_memory = float(params['pbtx_max_gpu_memory'])
    else:
        pbtx_max_memory = float(params['pbtx_max_memory'])
    EDI_allvols = join(sdir,"EDI","allvols")
    pbtk_dir = join(sdir,"EDI","PBTKresults")
    connectome_dir = join(sdir,"EDI","CNTMresults")
    bedpostxResults = join(sdir,"bedpostx_b1000.bedpostX")
    merged = join(bedpostxResults,"merged")
    nodif_brain_mask = join(bedpostxResults,"nodif_brain_mask.nii.gz")
    allvoxelscortsubcort = join(sdir,"allvoxelscortsubcort.nii.gz")
    terminationmask = join(sdir,"terminationmask.nii.gz")
    bs = join(sdir,"bs.nii.gz")

    assert exists(bedpostxResults), "Could not find {}".format(bedpostxResults)

    ### Memory Management ###

    node_name = platform.uname().node.strip()
    assert node_name and ' ' not in node_name, "Invalid node name {}".format(node_name)
    mem_record = join(derivatives_dir_tmp, node_name + '.json') # Keep record to avoid overusing node memory
    smart_mkdir(sdir_tmp)

    # Only access mem_record with file locking to avoid outdated data
    def open_mem_record(mode = 'r'):
        f = None
        while True:
            try:
                f = open(mem_record, mode, newline='')
                fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
                break
            except IOError as e:
                # raise on unrelated IOErrors
                if e.errno != errno.EAGAIN:
                    raise
                else:
                    time.sleep(0.1)
        assert f is not None, "Failed to open mem_record {}".format(mem_record)
        return f

    def estimate_total_memory_usage():
        f = open_mem_record('r')
        mem_dict = json.load(f)
        fcntl.flock(f, fcntl.LOCK_UN)
        f.close()
        mem_sum = 0.0
        for task_mem in mem_dict.values():
            mem_sum += float(task_mem)
        return mem_sum

    def estimate_task_mem_usage():
        total_size = 0
        total_size += os.path.getsize(allvoxelscortsubcort)
        total_size += os.path.getsize(terminationmask)
        total_size += os.path.getsize(bs)
        for dirpath, dirnames, filenames in os.walk(bedpostxResults):
            for f in filenames:
                fp = os.path.join(dirpath, f)
                if not os.path.islink(fp):
                    total_size += os.path.getsize(fp)

        max_region_size = 0
        for edge in edges:
            a, b = edge
            a_file = join(EDI_allvols, a + "_s2fa.nii.gz")
            b_file = join(EDI_allvols, b + "_s2fa.nii.gz")
            a_size = os.path.getsize(a_file)
            b_size = os.path.getsize(b_file)
            max_region_size = max([a_size, b_size, max_region_size])
        total_size += max_region_size
        return float(total_size) * 1.0E-9

    def add_task():
        # mem_record is created by the caller before any add_task call, so it can always be read here
        f = open_mem_record('r')
        mem_dict = json.load(f)
        task_ids = [int(x) for x in mem_dict.keys()] + [0] # append zero in case task_ids empty
        task_id = str(max(task_ids) + 1) # generate incremental task_id
        mem_dict[task_id] = task_mem_usage
        tmp_fp, tmp_path = tempfile.mkstemp(dir=sdir_tmp)
        os.close(tmp_fp) # close the raw descriptor; we reopen by pathname below
        with open(tmp_path, 'w', newline='') as tmp: # file pointer not consistent, so we open using the pathname
            json.dump(mem_dict, tmp)
        os.replace(tmp_path, mem_record) # atomic on POSIX systems. flock is advisory, so we can still overwrite.
        fcntl.flock(f, fcntl.LOCK_UN)
        f.close()
        return task_id

    def remove_task(task_id):
        f = open_mem_record('r')
        mem_dict = json.load(f)
        mem_dict.pop(task_id, None)
        tmp_fp, tmp_path = tempfile.mkstemp(dir=sdir_tmp)
        os.close(tmp_fp) # close the raw descriptor; we reopen by pathname below
        with open(tmp_path, 'w', newline='') as tmp:
            json.dump(mem_dict, tmp)
        os.replace(tmp_path, mem_record)
        fcntl.flock(f, fcntl.LOCK_UN)
        f.close()

    sleep_timeout = 7200

    if pbtx_max_memory > 0:
        task_mem_usage = estimate_task_mem_usage()
        total_sleep = 0
        # Memory record is atomic, but might not be updated on time
        # So we randomize sleep to discourage multiple tasks hitting at once
        init_sleep = random.randrange(0, 120)
        write(stdout, "Sleeping for {:d} seconds".format(init_sleep))
        total_sleep += init_sleep
        time.sleep(init_sleep)

        if not exists(mem_record):
            f = open_mem_record('w')
            json.dump({}, f)
            fcntl.flock(f, fcntl.LOCK_UN)
            f.close()

        total_mem_usage = estimate_total_memory_usage()
        # Then we sleep until memory usage is low enough
        while total_mem_usage + task_mem_usage > pbtx_max_memory:
            sleep_interval = random.randrange(30, 120)
            write(stdout, "Sleeping for {:d} seconds. Memory usage: {:.2f}/{:.2f} GB".format(sleep_interval, total_mem_usage, pbtx_max_memory))
            total_sleep += sleep_interval
            if total_sleep > sleep_timeout:
                raise Exception('Retrying task that has slept longer than 2 hours')
            time.sleep(sleep_interval)
            total_mem_usage = estimate_total_memory_usage()
        write(stdout, "Running Probtrackx after sleeping for {} seconds".format(total_sleep))

        # Insert task and memory usage into record
        task_id = add_task()

    ### Probtrackx Run ###
    start_time = time.time()
    try:
        for edge in edges:
            a, b = edge
            a_file = join(EDI_allvols, a + "_s2fa.nii.gz")
            b_file = join(EDI_allvols, b + "_s2fa.nii.gz")
            tmp = join(sdir, "tmp", "{}_to_{}".format(a, b))
            a_to_b_formatted = "{}_s2fato{}_s2fa.nii.gz".format(a,b)
            a_to_b_file = join(pbtk_dir,a_to_b_formatted)
            waypoints = join(tmp,"waypoint.txt")
            waytotal = join(tmp, "waytotal")
            assert exists(a_file) and exists(b_file), "Error: Both Freesurfer regions must exist: {} and {}".format(a_file, b_file)
            smart_remove(a_to_b_file)
            smart_remove(tmp)
            smart_mkdir(tmp)
            write(stdout, "Running subproc: {} to {}".format(a, b))
            if container:
                write(waypoints, b_file.replace(odir, "/share"))
            else:
                write(waypoints, b_file)

            exclusion = join(tmp,"exclusion.nii.gz")
            termination = join(tmp,"termination.nii.gz")
            run("fslmaths {} -sub {} {}".format(allvoxelscortsubcort, a_file, exclusion), params)
            run("fslmaths {} -sub {} {}".format(exclusion, b_file, exclusion), params)
            run("fslmaths {} -add {} {}".format(exclusion, bs, exclusion), params)
            run("fslmaths {} -add {} {}".format(terminationmask, b_file, termination), params)

            pbtx_args = (" -x {} ".format(a_file) +
                # " --pd -l -c 0.2 -S 2000 --steplength=0.5 -P 1000" +
                " --pd -l -c 0.2 -S 2000 --steplength=0.5 -P {}".format(pbtx_sample_count) +
                " --waypoints={} --avoid={} --stop={}".format(waypoints, exclusion, termination) +
                " --forcedir --opd --rseed={}".format(subject_random_seed) +
                " -s {}".format(merged) +
                " -m {}".format(nodif_brain_mask) +
                " --dir={}".format(tmp) +
                " --out={}".format(a_to_b_formatted)
                )
            if use_gpu:
                probtrackx2_sh = join(tmp, "probtrackx2.sh")
                smart_remove(probtrackx2_sh)
                write(probtrackx2_sh, "export CUDA_LIB_DIR=$CUDA_8_LIB_DIR\n" +
                               "export LD_LIBRARY_PATH=$CUDA_LIB_DIR:$LD_LIBRARY_PATH\n" +
                               "probtrackx2_gpu" + pbtx_args.replace(odir, "/share"))
                run("sh " + probtrackx2_sh, params)
            else:
                run("probtrackx2" + pbtx_args, params)

            waytotal_count = 0
            if exists(waytotal):
                with open(waytotal, 'r') as f:
                    waytotal_count = f.read().strip()
                    fdt_count = run("fslmeants -i {} -m {} | head -n 1".format(join(tmp, a_to_b_formatted), b_file), params) # based on getconnectome script
                    if not is_float(waytotal_count):
                        write(stdout, "Error: Failed to read waytotal_count value {} in {}".format(waytotal_count, edge))
                        continue
                    if not is_float(fdt_count):
                        write(stdout, "Error: Failed to read fdt_count value {} in {}".format(fdt_count, edge))
                        continue
                    edge_file = join(connectome_dir, "{}_to_{}.dot".format(a, b))
                    smart_remove(edge_file)
                    write(edge_file, "{} {} {} {}".format(a, b, waytotal_count, fdt_count))

                    # Error check edge file
                    with open(edge_file) as f:
                        line = f.read().strip()
                        if len(line) > 0: # ignore empty lines
                            chunks = [x.strip() for x in line.split(' ') if x]
                            if not (len(chunks) == 4 and is_float(chunks[2]) and is_float(chunks[3])):
                                write(stdout, "Error: Connectome {} has invalid edge {} to {}".format(edge_file, a, b))
                                continue
            else:
                write(stdout, 'Error: failed to find waytotal for {} to {}'.format(a, b))
            copyfile(join(tmp, a_to_b_formatted), a_to_b_file) # keep edi output
            if not a == "lh.paracentral": # discard all temp files except these for debugging
                smart_remove(tmp)
    finally:
        if pbtx_max_memory > 0:
            remove_task(task_id)

    record_apptime(params, start_time, 1)
Example #16
def s1_1_dicom_preproc(params, inputs=[]):
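    """Convert the subject's DICOMs (or copy source NiFTIs) into a BIDS layout.

    Either copies hardi/anat/bvals/bvecs from src_nifti_dir, or runs dcm2nii on
    the T1, DTI, and extra-b0 DICOM directories; b0 slices are identified as
    intensity outliers from the slice median, averaged into a single volume, and
    merged with the DTI slices into hardi.nii.gz. Padding zeros are cleaned from
    bvals/bvecs, and the temporary DICOM files are archived.
    """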
    import time,tarfile
    from subscripts.utilities import run,record_apptime,record_start,smart_remove,smart_copy, \
                                     smart_mkdir,write,strip_trailing_slash
    from os.path import join,split,exists,basename
    from shutil import copyfile
    from glob import glob
    import numpy as np
    sdir = params['sdir']
    stdout = params['stdout']
    T1_dicom_dir = params['T1_dicom_dir']
    DTI_dicom_dir = params['DTI_dicom_dir']
    extra_b0_dirs = params['extra_b0_dirs']
    src_nifti_dir = params['src_nifti_dir']

    sourcedata_dir = params['sourcedata_dir']
    rawdata_dir = params['rawdata_dir']
    derivatives_dir = params['derivatives_dir']
    bids_dicom_dir = params['bids_dicom_dir']
    bids_nifti_dir = params['bids_nifti_dir']
    subject_name = params['subject_name']
    session_name = params['session_name']

    container = params['container']
    dicom_tmp_dir = join(sdir, 'tmp_dicom')

    smart_remove(dicom_tmp_dir)
    smart_mkdir(dicom_tmp_dir)
    
    smart_mkdir(join(bids_nifti_dir, "dwi"))
    smart_mkdir(join(bids_nifti_dir, "anat"))
    DTI_dicom_tmp_dir = join(dicom_tmp_dir, 'DTI')
    T1_dicom_tmp_dir = join(dicom_tmp_dir, 'T1')
    extra_b0_tmp_dirs = [join(dicom_tmp_dir, basename(dirname)) for dirname in extra_b0_dirs]

    hardi_file = join(bids_nifti_dir, "dwi", "{}_{}_dwi.nii.gz".format(subject_name, session_name))
    T1_file = join(bids_nifti_dir, "anat", "{}_{}_T1w.nii.gz".format(subject_name, session_name))
    bvals_file = join(bids_nifti_dir, "dwi", "{}_{}_dwi.bval".format(subject_name, session_name))
    bvecs_file = join(bids_nifti_dir, "dwi", "{}_{}_dwi.bvec".format(subject_name, session_name))

    start_time = time.time()
    record_start(params)

    if src_nifti_dir:
        smart_copy(join(src_nifti_dir, "hardi.nii.gz"), hardi_file)
        smart_copy(join(src_nifti_dir, "anat.nii.gz"), T1_file)
        smart_copy(join(src_nifti_dir, "bvals"), bvals_file)
        smart_copy(join(src_nifti_dir, "bvecs"), bvecs_file)
    elif T1_dicom_dir and DTI_dicom_dir:
        smart_remove(DTI_dicom_tmp_dir)
        smart_remove(T1_dicom_tmp_dir)

        # copy everything from DICOM dir except old NiFTI outputs
        smart_copy(T1_dicom_dir, T1_dicom_tmp_dir, ['*.nii', '*.nii.gz', '*.bval', '*.bvec'])
        write(stdout, 'Copied {} to {}'.format(T1_dicom_dir, T1_dicom_tmp_dir))
        smart_copy(DTI_dicom_dir, DTI_dicom_tmp_dir, ['*.nii', '*.nii.gz', '*.bval', '*.bvec'])
        write(stdout, 'Copied {} to {}'.format(DTI_dicom_dir, DTI_dicom_tmp_dir))
        for (extra_b0_dir, extra_b0_tmp_dir) in zip(extra_b0_dirs, extra_b0_tmp_dirs):
            smart_remove(extra_b0_tmp_dir)
            smart_copy(extra_b0_dir, extra_b0_tmp_dir, ['*.nii', '*.nii.gz', '*.bval', '*.bvec'])
            write(stdout, 'Copied {} to {}'.format(extra_b0_dir, extra_b0_tmp_dir))

        # Run dcm2nii in script to ensure Singularity container finds the right paths
        dicom_sh = join(sdir, "dicom.sh")
        smart_remove(dicom_sh)

        # Convert DTI dicom to many individual NiFTI files
        dicom_sh_contents = "dcm2nii -4 N"
        for file in glob(join(DTI_dicom_tmp_dir, '*.dcm')):
            dicom_sh_contents += " " + file

        for extra_b0_tmp_dir in extra_b0_tmp_dirs:
            dicom_sh_contents += "\ndcm2nii -4 N"
            for file in glob(join(extra_b0_tmp_dir, '*.dcm')):
                dicom_sh_contents += " " + file

        dicom_sh_contents += "\ndcm2nii -4 N"
        for file in glob(join(T1_dicom_tmp_dir, '*.dcm')):
            dicom_sh_contents += " " + file

        if container:
            odir = split(sdir)[0]
            write(dicom_sh, dicom_sh_contents.replace(odir, "/share"))
        else:
            write(dicom_sh, dicom_sh_contents)
        write(stdout, 'Running dcm2nii with script {}'.format(dicom_sh))
        run("sh " + dicom_sh, params)

        b0_slices = {}
        normal_slices = []
        all_slices = {}

        # Check that dcm2nii outputs exist
        found_bvals = glob(join(DTI_dicom_tmp_dir, '*.bval'))
        found_bvecs = glob(join(DTI_dicom_tmp_dir, '*.bvec'))
        found_T1 = glob(join(T1_dicom_tmp_dir, 'co*.nii.gz'))

        if len(found_bvals) != 1:
            raise Exception('Did not find exactly one bvals output in {}'.format(DTI_dicom_tmp_dir))
        else:
            copyfile(found_bvals[0], bvals_file)

        if len(found_bvecs) != 1:
            raise Exception('Did not find exactly one bvecs output in {}'.format(DTI_dicom_tmp_dir))
        else:
            copyfile(found_bvecs[0], bvecs_file)

        # If we don't find the usual T1 file name, just try any NifTI file in the T1 directory
        if len(found_T1) == 0:
            found_T1 = glob(join(T1_dicom_tmp_dir, '*.nii.gz'))
        if len(found_T1) == 0:
            raise Exception('Did not find T1 output in {}'.format(T1_dicom_tmp_dir))
        elif len(found_T1) > 1:
            write(stdout, 'Warning: Found more than one T1 output in {}'.format(T1_dicom_tmp_dir))
        found_T1.sort()
        copyfile(found_T1[0], T1_file)

        # Copy extra b0 values to DTI temp dir
        for extra_b0_tmp_dir in extra_b0_tmp_dirs:
            for file in glob(join(extra_b0_tmp_dir, "*.nii.gz")):
                copyfile(file, join(DTI_dicom_tmp_dir, "extra_b0_" + basename(file)))
            write(stdout, 'Copied NiFTI outputs from {} to {}'.format(extra_b0_tmp_dir, DTI_dicom_tmp_dir))

        # Sort slices into DTI and b0
        for file in glob(join(DTI_dicom_tmp_dir, '*.nii.gz')):
            slice_val = run("fslmeants -i {} | head -n 1".format(file), params) # based on getconnectome script
            all_slices[file] = float(slice_val)
        normal_median = np.median(list(all_slices.values()))
        for file in list(all_slices.keys()):
            slice_val = all_slices[file]
            # mark as b0 if more than 20% from normal slice median
            if abs(slice_val - normal_median) > 0.2 * normal_median:
                b0_slices[file] = slice_val
            else:
                normal_slices.append(file)
        if not b0_slices:
            raise Exception('Failed to find b0 values in {}'.format(DTI_dicom_dir))
        write(stdout, 'Found {} normal DTI slices'.format(len(normal_slices)))

        # Remove outliers from b0 values
        max_outliers = 1
        if len(b0_slices) > max_outliers:
            num_outliers = 0
            b0_median = np.median(list(b0_slices.values()))
            for file in list(b0_slices.keys()):
                slice_val = b0_slices[file]
                # remove outlier if more than 20% from b0 median
                if abs(slice_val - b0_median) > 0.2 * b0_median:
                    b0_slices.pop(file)
                    num_outliers += 1
            if num_outliers > max_outliers:
                raise Exception('Found more than {} outliers in b0 values. This probably means that this script has incorrectly identified b0 slices.'.format(max_outliers))
        write(stdout, 'Found {} b0 slices'.format(len(b0_slices)))

        # Average b0 slices into a single image
        avg_b0 = join(DTI_dicom_tmp_dir, 'avg_b0.nii.gz')
        smart_remove(avg_b0)
        for file in list(b0_slices.keys()):
            if not exists(avg_b0):
                copyfile(file, avg_b0)
            else:
                run("fslmaths {0} -add {1} {1}".format(file, avg_b0), params)
        run("fslmaths {0} -div {1} {0}".format(avg_b0, len(b0_slices)), params)

        # Concatenate average b0 and DTI slices into a single hardi.nii.gz
        normal_slices.sort()
        tmp_hardi = join(dicom_tmp_dir, "hardi.nii.gz")
        run("fslmerge -t {} {}".format(tmp_hardi, " ".join([avg_b0] + normal_slices)), params)
        copyfile(tmp_hardi, hardi_file)
        write(stdout, 'Concatenated b0 and DTI slices into {}'.format(hardi_file))

        # Clean extra zeroes from bvals and bvecs files
        num_slices = len(normal_slices)
        with open(bvals_file, 'r+') as f:
            entries = [x.strip() for x in f.read().split() if x]
            extra_zero = entries.pop(0) # strip leading zero
            if extra_zero != "0":
                raise Exception("{} should begin with zero, as a placeholder for the averaged b0 slice".format(bvals_file))

            # remove zero sequences
            min_sequence_length = 5
            if all(x == "0" for x in entries[0:min_sequence_length]):
                write(stdout, "Stripped leading zero sequence from {}".format(bvals_file))
                while len(entries) > num_slices:
                    extra_zero = entries.pop(0)
                    if extra_zero != "0":
                        raise Exception("Failed to clean extra zeros from {}".format(bvals_file))
            elif all(x == "0" for x in entries[-1:-min_sequence_length-1:-1]):
                write(stdout, "Stripped trailing zero sequence from {}".format(bvals_file))
                while len(entries) > num_slices:
                    extra_zero = entries.pop(-1)
                    if extra_zero != "0":
                        raise Exception("Failed to clean extra zeros from {}".format(bvals_file))

            if len(entries) > num_slices:
                raise Exception('Failed to clean bvals file {}. Since {} has {} slices, bvals must have {} columns'.
                    format(bvals_file, hardi_file, num_slices, num_slices))
            text = "0 " + " ".join(entries) + "\n" # restore leading zero
            f.seek(0)
            f.write(text)
            f.truncate()
            write(stdout, 'Generated bvals file with values:\n{}'.format(text))
        with open(bvecs_file, 'r+') as f:
            text = ""
            for line in f.readlines():
                if not line:
                    continue
                entries = [x.strip() for x in line.split() if x]
                extra_zero = entries.pop(0) # strip leading zero
                if extra_zero != "0":
                    raise Exception("Each line in {} should begin with zero, as a placeholder for the averaged b0 slice".format(bvecs_file))

                # remove zero sequences
                min_sequence_length = 5
                if all(x == "0" for x in entries[0:min_sequence_length]):
                    write(stdout, "Stripped leading zero sequence from {}".format(bvecs_file))
                    while len(entries) > num_slices:
                        extra_zero = entries.pop(0)
                        if extra_zero != "0":
                            raise Exception("Failed to clean extra zeros from {}".format(bvecs_file))
                elif all(x == "0" for x in entries[-1:-min_sequence_length-1:-1]):
                    write(stdout, "Stripped trailing zero sequence from {}".format(bvecs_file))
                    while len(entries) > num_slices:
                        extra_zero = entries.pop(-1)
                        if extra_zero != "0":
                            raise Exception("Failed to clean extra zeros from {}".format(bvecs_file))

                if len(entries) > num_slices:
                    raise Exception('Failed to clean bvecs file {}. Since {} has {} slices, bvecs must have {} columns'.
                        format(bvecs_file, hardi_file, num_slices, num_slices))
                text += "0 " + " ".join(entries) + "\n" # restore leading zero
            f.seek(0)
            f.write(text)
            f.truncate()
            write(stdout, 'Generated bvecs file with values:\n{}'.format(text))

        # Compress DICOM inputs
        dicom_tmp_archive = join(bids_dicom_dir, 'sourcedata.tar.gz')
        smart_remove(dicom_tmp_archive)
        with tarfile.open(dicom_tmp_archive, mode='w:gz') as archive:
            archive.add(dicom_tmp_dir, recursive=True, arcname=basename(dicom_tmp_dir))
        smart_remove(dicom_tmp_dir)
        write(stdout, 'Compressed temporary DICOM files to {}'.format(dicom_tmp_archive))

    smart_copy(hardi_file, join(sdir, "hardi.nii.gz"))
    smart_copy(T1_file, join(sdir,"T1.nii.gz"))
    smart_copy(bvecs_file, join(sdir,"bvecs"))
    smart_copy(bvals_file, join(sdir,"bvals"))
    record_apptime(params, start_time, 1)