Example #1
0
def test_streams2graph(fa_wei):
    """Smoke-test ``streams2graph`` against the example data next to this file.

    The test builds a gradient table from the example bvals/bvecs, pickles it
    to ``<examples>/gtab.pkl``, estimates an FA image via ``tens_mod_fa_est``
    and finally asserts that the third element returned by ``streams2graph``
    is not ``None``.

    :param fa_wei: forwarded unchanged as the last positional argument of
        ``streams2graph`` (supplied by the test harness; its parametrization
        is not visible in this block).
    """
    from dipy.core.gradients import gradient_table
    from dipy.io import save_pickle

    base_dir = str(Path(__file__).parent/"examples")
    dwi_file = f"{base_dir}/003/test_out/003/dwi/sub-003_dwi_reor-RAS_res-2mm.nii.gz"
    conn_model = 'csd'
    min_length = 10
    error_margin = 2
    directget = 'prob'
    track_type = 'particle'
    target_samples = 500
    overlap_thr = 1
    min_span_tree = True
    prune = 3
    norm = 6
    binary = False
    roi = f"{base_dir}/miscellaneous/pDMN_3_bin.nii.gz"
    network = 'Default'
    ID = '003'
    parc = True
    disp_filt = False
    node_size = None
    dens_thresh = False
    atlas = 'whole_brain_cluster_labels_PCA200'
    uatlas = None

    # The pickles are fixtures shipped under the examples directory; do not
    # point these paths at untrusted files, since unpickling executes code.
    # Context managers make sure both handles are closed again.
    coord_file_path = f"{base_dir}/miscellaneous/Default_func_coords_wb.pkl"
    with open(coord_file_path, 'rb') as coord_file:
        coords = pickle.load(coord_file)
    labels_file_path = f"{base_dir}/miscellaneous/Default_func_labelnames_wb.pkl"
    with open(labels_file_path, 'rb') as labels_file:
        labels = pickle.load(labels_file)

    # Not actually normalized to mni-space in this test.
    atlas_mni = f"{base_dir}/003/dmri/whole_brain_cluster_labels_PCA200_dwi_track.nii.gz"
    streams = f"{base_dir}/miscellaneous/003_streamlines_est-csd_nodetype-parc_samples-1000streams_tt-particle_dg-prob_ml-10.trk"
    B0_mask = f"{base_dir}/003/anat/mean_B0_bet_mask_tmp.nii.gz"
    dir_path = f"{base_dir}/003/dmri"
    bvals = f"{dir_path}/sub-003_dwi.bval"
    bvecs = f"{base_dir}/003/test_out/003/dwi/bvecs_reor.bvec"
    gtab_file = f"{base_dir}/gtab.pkl"

    # Rebuild the b0 mask by hand: every b-value below the threshold of 50 is
    # treated as 0 before the mask is derived.
    gtab = gradient_table(bvals, bvecs)
    gtab.b0_threshold = 50
    gtab_bvals = gtab.bvals.copy()
    b0_thr_ixs = np.where(gtab_bvals < gtab.b0_threshold)[0]
    gtab_bvals[b0_thr_ixs] = 0
    gtab.b0s_mask = gtab_bvals == 0
    save_pickle(gtab_file, gtab)

    # Not actually normalized to mni-space in this test.
    warped_fa = tens_mod_fa_est(gtab_file, dwi_file, B0_mask)[0]

    # Only the third element of the returned tuple is inspected here.
    conn_matrix = streams2graph(atlas_mni, streams, overlap_thr, dir_path, track_type, target_samples,
                                conn_model, network, node_size, dens_thresh, ID, roi, min_span_tree,
                                disp_filt, parc, prune, atlas, uatlas, labels, coords, norm, binary,
                                directget, warped_fa, error_margin, min_length, fa_wei)[2]
    assert conn_matrix is not None
Example #2
0
def test_streams2graph(fa_wei, dsn):
    """Smoke-test ``streams2graph`` against the example data next to this file.

    The test builds a gradient table from the example bvals/bvecs, pickles it
    to ``<examples>/gtab.pkl``, estimates an FA image via ``tens_mod_fa_est``,
    derives node coordinates from the parcellation with
    ``nodemaker.get_names_and_coords_of_parcels`` and asserts that the third
    element returned by ``streams2graph`` is not ``None``.

    :param fa_wei: forwarded unchanged as the last positional argument of
        ``streams2graph``.
    :param dsn: accepted for interface compatibility but currently ignored;
        the graph is always built from ``atlas_dwi`` and the original
        ``streams`` file.
    """
    from dipy.core.gradients import gradient_table
    from pynets.core import nodemaker
    from dipy.io import save_pickle

    base_dir = str(Path(__file__).parent / "examples")
    dwi_file = f"{base_dir}/003/test_out/003/dwi/sub-003_dwi_reor-RAS_res-2mm.nii.gz"
    conn_model = 'csd'
    min_length = 10
    error_margin = 2
    directget = 'prob'
    track_type = 'particle'
    target_samples = 500
    overlap_thr = 1
    min_span_tree = True
    prune = 3
    norm = 6
    binary = False
    roi = f"{base_dir}/miscellaneous/pDMN_3_bin.nii.gz"
    network = 'Default'
    ID = '003'
    parc = True
    disp_filt = False
    node_size = None
    dens_thresh = False
    atlas = 'whole_brain_cluster_labels_PCA200'
    uatlas = f"{base_dir}/miscellaneous/whole_brain_cluster_labels_PCA200.nii.gz"
    atlas_dwi = f"{base_dir}/003/dmri/whole_brain_cluster_labels_PCA200_dwi_track.nii.gz"
    streams = (
        f"{base_dir}/miscellaneous/003_streamlines_est-csd_nodetype-parc_samples-1000streams_tt-particle_"
        f"dg-prob_ml-10.trk"
    )
    B0_mask = f"{base_dir}/003/anat/mean_B0_bet_mask_tmp.nii.gz"
    dir_path = f"{base_dir}/003/dmri"
    bvals = f"{dir_path}/sub-003_dwi.bval"
    bvecs = f"{base_dir}/003/test_out/003/dwi/bvecs_reor.bvec"
    gtab_file = f"{base_dir}/gtab.pkl"

    # Rebuild the b0 mask by hand: every b-value below the threshold of 50 is
    # treated as 0 before the mask is derived.
    gtab = gradient_table(bvals, bvecs)
    gtab.b0_threshold = 50
    gtab_bvals = gtab.bvals.copy()
    b0_thr_ixs = np.where(gtab_bvals < gtab.b0_threshold)[0]
    gtab_bvals[b0_thr_ixs] = 0
    gtab.b0s_mask = gtab_bvals == 0
    save_pickle(gtab_file, gtab)
    fa_path = tens_mod_fa_est(gtab_file, dwi_file, B0_mask)[0]

    coords = nodemaker.get_names_and_coords_of_parcels(uatlas)[0]
    # Labels are simply 1..len(coords), as plain Python ints.
    labels = np.arange(1, len(coords) + 1).tolist()

    # NOTE: `dsn` is not consulted. A direct-streamline-normalization branch
    # used to live here but was disabled, so only this single code path runs.
    conn_matrix = streams2graph(
        atlas_dwi, streams, overlap_thr, dir_path, track_type, target_samples,
        conn_model, network, node_size, dens_thresh, ID, roi, min_span_tree,
        disp_filt, parc, prune, atlas, atlas_dwi, labels, coords, norm, binary,
        directget, fa_path, error_margin, min_length, fa_wei)[2]

    assert conn_matrix is not None
Example #3
0
def test_streams2graph(dmri_estimation_data,
                       tractography_estimation_data,
                       random_mni_roi_data,
                       dsn=False):
    """Smoke-test ``streams2graph`` using fixture-provided diffusion data.

    The test builds a gradient table, pickles it to the fixture's
    ``gtab_file`` path, estimates FA and anisotropic-power images, resamples
    the packaged PCA200 parcellation onto the FA image and asserts that the
    third element returned by ``streams2graph`` is not ``None``.

    :param dmri_estimation_data: mapping read for the keys ``'dwi_file'``,
        ``'B0_mask'``, ``'fbvals'``, ``'fbvecs'``, ``'gtab_file'``,
        ``'t1w_file'`` and ``'f_pve_gm'``.
    :param tractography_estimation_data: mapping read for the key ``'trk'``.
    :param random_mni_roi_data: mapping read for the key ``'roi_file'``.
    :param dsn: when exactly ``True``, ``direct_streamline_norm`` is run first
        and its outputs replace the corresponding inputs of
        ``streams2graph``.
    """
    from pynets.registration.register import direct_streamline_norm
    from dipy.core.gradients import gradient_table
    from dipy.io import save_pickle
    import random
    import pkg_resources

    dwi_file = dmri_estimation_data['dwi_file']
    conn_model = 'csd'
    min_length = 10
    error_margin = 5
    traversal = 'prob'
    track_type = 'local'
    min_span_tree = False
    prune = 3
    norm = 6
    binary = False
    roi = random_mni_roi_data['roi_file']
    subnet = 'Default'
    ID = '003'
    parc = True
    disp_filt = False
    node_radius = None
    dens_thresh = False
    atlas = 'whole_brain_cluster_labels_PCA200'
    parcellation = pkg_resources.resource_filename(
        "pynets", "templates/atlases/whole_brain_cluster_labels_PCA200.nii.gz")
    streams = tractography_estimation_data['trk']
    B0_mask = dmri_estimation_data['B0_mask']
    bvals = dmri_estimation_data['fbvals']
    bvecs = dmri_estimation_data['fbvecs']
    gtab_file = dmri_estimation_data['gtab_file']

    # Rebuild the b0 mask by hand: every b-value below the threshold of 50 is
    # treated as 0 before the mask is derived.
    gtab = gradient_table(bvals, bvecs)
    gtab.b0_threshold = 50
    gtab_bvals = gtab.bvals.copy()
    b0_thr_ixs = np.where(gtab_bvals < gtab.b0_threshold)[0]
    gtab_bvals[b0_thr_ixs] = 0
    gtab.b0s_mask = gtab_bvals == 0
    save_pickle(gtab_file, gtab)

    fa_path = tens_mod_fa_est(gtab_file, dwi_file, B0_mask)[0]
    ap_path = create_anisopowermap(gtab_file, dwi_file, B0_mask)[0]
    t1w_brain = dmri_estimation_data['t1w_file']
    t1w_gm = dmri_estimation_data['f_pve_gm']

    # Resample the parcellation onto the FA image grid; nearest-neighbour
    # interpolation is requested explicitly.
    atlas_in_dwi = fname_presuffix(ap_path,
                                   suffix="atlas_in_dwi",
                                   use_ext=True)
    resample_to_img(nib.load(parcellation),
                    nib.load(fa_path),
                    interpolation="nearest").to_filename(atlas_in_dwi)

    # One random coordinate per unique atlas value, minus one
    # (presumably the background label -- TODO confirm).
    coords = [
        (random.random() * 2.0, random.random() * 2.0, random.random() * 2.0)
        for _ in range(len(np.unique(nib.load(atlas_in_dwi).get_fdata())) - 1)
    ]
    # Labels are simply 1..len(coords), as plain Python ints.
    labels = np.arange(1, len(coords) + 1).tolist()

    # The scratch directory is only handed to direct_streamline_norm; the
    # context manager removes it deterministically, even if the test fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dir_path = str(tmp_dir)

        if dsn is True:
            os.makedirs(f"{dir_path}/dmri_reg/DSN", exist_ok=True)
            (streams, dir_path, track_type, conn_model, subnet, node_radius,
             dens_thresh, ID, roi, min_span_tree, disp_filt, parc, prune,
             atlas, parcellation, labels, coords, norm, binary,
             atlas_for_streams, traversal, fa_path,
             min_length) = direct_streamline_norm(streams,
                                                  fa_path,
                                                  ap_path,
                                                  dir_path,
                                                  track_type,
                                                  conn_model,
                                                  subnet,
                                                  node_radius,
                                                  dens_thresh,
                                                  ID,
                                                  roi,
                                                  min_span_tree,
                                                  disp_filt,
                                                  parc,
                                                  prune,
                                                  atlas,
                                                  atlas_in_dwi,
                                                  parcellation,
                                                  labels,
                                                  coords,
                                                  norm,
                                                  binary,
                                                  t1w_gm,
                                                  dir_path, [0.1, 0.2], [40, 30],
                                                  traversal,
                                                  min_length,
                                                  t1w_brain,
                                                  run_dsn=True)

        # Only the third element of the returned tuple is inspected here.
        conn_matrix = streams2graph(
            atlas_in_dwi, streams, os.path.dirname(streams), track_type,
            conn_model, subnet, node_radius, dens_thresh, ID, roi, min_span_tree,
            disp_filt, parc, prune, atlas, atlas_in_dwi, labels, coords, norm,
            binary, traversal, fa_path, min_length, error_margin)[2]

        assert conn_matrix is not None