コード例 #1
0
ファイル: plotting.py プロジェクト: kaczmarj/PyNets
def plot_all(conn_matrix, conn_model, atlas_select, dir_path, ID, network, label_names, mask, coords, edge_threshold, plot_switch):
    """Render the connectogram, adjacency-matrix and glass-brain connectome plots.

    Nothing is done unless ``plot_switch == True``. When plotting, the graph is
    pruned with ``pynets.netstats.most_important`` and the entries of
    ``label_names`` / ``coords`` at the pruned positions are dropped so they
    stay aligned with the pruned matrix.

    Parameters
    ----------
    conn_matrix : array-like
        Square connectivity matrix.
    conn_model : str
        Connectivity model name (used in output file names).
    atlas_select : str
        Atlas name (used in output file names).
    dir_path : str
        Output directory.
    ID : str
        Subject identifier.
    network : str or None
        Network name; ``None``, ``''`` or the string ``'None'`` mean
        whole-brain and add nothing to the file name.
    label_names : sequence
        Node labels, one per row of ``conn_matrix``. Copied; the caller's
        object is not modified.
    mask : str or None
        Mask image path; its base name is added to the output file name.
    coords : sequence
        Node coordinates, one per row of ``conn_matrix``. Copied.
    edge_threshold : str or float
        Passed to nilearn's ``add_graph``.
    plot_switch : bool
        Plot only when this equals ``True``.
    """
    import os
    import numpy as np

    if plot_switch != True:  # noqa: E712 - strict comparison kept from the original
        return

    import pkg_resources
    import networkx as nx
    import matplotlib.pyplot as plt
    from nilearn import plotting as niplot
    from pynets import plotting
    from pynets.netstats import most_important

    pruning = True
    dpi_resolution = 1000

    # Work on copies so pruning does not mutate the caller's lists.
    label_names = list(label_names)
    coords = list(coords)

    # from_numpy_array (networkx >= 2.0) replaces from_numpy_matrix, which was
    # removed in networkx 3.0; to_numpy_array below already needs >= 2.0.
    G_pre = nx.from_numpy_array(conn_matrix)
    pruned_nodes = []
    pruned_edges = []
    if pruning:
        G, pruned_nodes, pruned_edges = most_important(G_pre)
    else:
        G = G_pre
    conn_matrix = nx.to_numpy_array(G)

    # Delete by position, highest first, so earlier deletions do not shift the
    # positions still pending. Deleting directly by index (instead of looking
    # the value up again with list.index) removes the right entry even when a
    # label or coordinate occurs more than once.
    for j in sorted(pruned_nodes, reverse=True):
        del label_names[j]
        del coords[j]
    # NOTE(review): the values returned as ``pruned_edges`` are treated as node
    # positions here, as in the original; confirm against most_important().
    for j in sorted(pruned_edges, reverse=True):
        del label_names[j]
        del coords[j]

    # Plot connectogram
    if len(conn_matrix) > 20:
        try:
            plotting.plot_connectogram(conn_matrix, conn_model, atlas_select, dir_path, ID, network, label_names)
        except RuntimeError:
            print('\n\n\nError: Connectogram plotting failed!')
    else:
        print('Error: Cannot plot connectogram for graphs smaller than 20 x 20!')

    # Plot adj. matrix based on determined inputs
    plotting.plot_conn_mat(conn_matrix, conn_model, atlas_select, dir_path, ID, network, label_names, mask)

    # Build the connectome figure path: <ID>_<atlas>_<model>[_<mask>][_<network>]_connectome_viz.png
    out_path_fig = dir_path + '/' + str(ID) + '_' + str(atlas_select) + '_' + str(conn_model)
    if mask is not None:
        out_path_fig += '_' + str(os.path.basename(mask).split('.')[0])
    if network and network != 'None':
        out_path_fig += '_' + str(network)
    out_path_fig += '_connectome_viz.png'

    # Plot connectome over the ch2better template shipped with pynets. An
    # empty 1x1 connectome creates the display; the real graph is added after.
    ch2better_loc = pkg_resources.resource_filename("pynets", "templates/ch2better.nii.gz")
    connectome = niplot.plot_connectome(np.zeros(shape=(1, 1)), [(0, 0, 0)], black_bg=True, node_size=0.0001)
    try:
        connectome.add_overlay(ch2better_loc, alpha=0.4, cmap=plt.cm.gray)
        # Symmetric colour range around zero
        z_max = np.abs(conn_matrix).max()
        connectome.add_graph(conn_matrix, coords, edge_threshold=edge_threshold, edge_cmap='Blues', edge_vmax=z_max, edge_vmin=-z_max, node_size=4)
        connectome.savefig(out_path_fig, dpi=dpi_resolution)
    finally:
        # Release the matplotlib figure held by the display.
        connectome.close()
    return
コード例 #2
0
def test_plot_connectogram():
    """Smoke-test plot_connectogram on the bundled whole-brain example (subject 997)."""
    # Set example inputs
    base_dir = str(Path(__file__).parent / "examples")
    dir_path = base_dir + '/997'
    network = None
    ID = '997'
    conn_model = 'sps'
    atlas_select = 'whole_brain_cluster_labels_PCA200'
    conn_matrix = np.genfromtxt(dir_path + '/whole_brain_cluster_labels_PCA200/997_est_sps_0.94.txt')
    labels_file_path = dir_path + '/whole_brain_cluster_labels_PCA200/WB_func_labelnames_wb.pkl'
    # Close the pickle file deterministically instead of leaking the handle.
    with open(labels_file_path, 'rb') as labels_file:
        label_names = pickle.load(labels_file)

    plotting.plot_connectogram(conn_matrix, conn_model, atlas_select, dir_path, ID, network, label_names)
コード例 #3
0
ファイル: test_plotting.py プロジェクト: lqcheng2017/PyNets
def test_plot_connectogram():
    """Smoke-test plot_connectogram on the bundled subject-997 'Default' example and report its runtime."""
    # Set example inputs
    base_dir = str(Path(__file__).parent / "examples")
    dir_path = base_dir + '/997'
    network = None
    ID = '997'
    conn_model = 'sps'
    atlas_select = 'whole_brain_cluster_labels_PCA200'
    conn_matrix = np.genfromtxt(
        dir_path +
        '/whole_brain_cluster_labels_PCA200/997_Default_est_sps_0.94.txt')
    labels_file_path = dir_path + '/whole_brain_cluster_labels_PCA200/Default_func_labelnames_wb.pkl'
    # Close the pickle file deterministically instead of leaking the handle.
    with open(labels_file_path, 'rb') as labels_file:
        label_names = pickle.load(labels_file)

    # perf_counter is monotonic, so the elapsed time cannot go negative if the
    # wall clock is adjusted during the run.
    start_time = time.perf_counter()
    plotting.plot_connectogram(conn_matrix, conn_model, atlas_select, dir_path,
                               ID, network, label_names)
    print("%s%s%s" % ('plot_connectogram --> finished: ',
                      str(np.round(time.perf_counter() - start_time, 1)), 's'))
コード例 #4
0
ファイル: plotting.py プロジェクト: lqcheng2017/PyNets
def plot_all(conn_matrix, conn_model, atlas_select, dir_path, ID, network,
             label_names, mask, coords, thr, node_size, edge_threshold, smooth,
             prune, uatlas_select):
    """Plot the connectogram, adjacency matrix and glass-brain connectome.

    Optionally prunes the graph for display, writes the (possibly pruned)
    coordinates and labels to pickles in ``dir_path``, and saves a
    ``*func_glass_viz.png`` glass-brain figure there.

    Parameters
    ----------
    conn_matrix : numpy.ndarray
        Square connectivity matrix.
    conn_model : str
        Connectivity model name (used in file names).
    atlas_select : str or bytes
        Atlas name; ``bytes`` are decoded as UTF-8.
    dir_path : str
        Output directory.
    ID : str
        Subject identifier.
    network : str or None
        Network name; falsy means whole-brain.
    label_names, coords : sequence
        Node labels / coordinates, one per row of ``conn_matrix``. They are
        copied, so the caller's objects are not modified.
    mask : str or None
        Mask image path; its base name is added to the output file names.
    thr : float or str
        Threshold value (recorded in the figure file name).
    node_size : int, str or None
        Node size (recorded as ``<n>mm`` in the file name); falsy or
        ``'None'`` is treated as ``'parc'``.
    edge_threshold : str or float
        Passed to nilearn's ``add_graph``.
    smooth : float, str or None
        Smoothing FWHM; values > 0 are recorded as ``<smooth>fwhm_`` in the
        file name, anything else (including None) as ``nosm_``.
    prune : int
        1 = prune with ``prune_disconnected``, 2 = prune with
        ``most_important``; any other value disables pruning.
    uatlas_select : str or None
        Parcellation image whose contours are drawn when ``node_size`` is
        ``'parc'``.

    Raises
    ------
    RuntimeWarning
        If the number of coordinates does not match ``conn_matrix``.
    """
    import os
    import pickle
    import numpy as np
    import matplotlib
    matplotlib.use('agg')
    from matplotlib import pyplot as plt
    from nilearn import plotting as niplot
    import pkg_resources
    import networkx as nx
    from pynets import plotting, thresholding
    from pynets.netstats import most_important, prune_disconnected

    coords = list(coords)
    label_names = list(label_names)

    dpi_resolution = 500
    # The atlas name may arrive as bytes; decode so it can go into file names.
    # (The previous "'b" substring test raised TypeError for bytes and
    # AttributeError for a str that happened to match.)
    if isinstance(atlas_select, bytes):
        atlas_select = atlas_select.decode('utf-8')

    if prune in (1, 2) and len(coords) == conn_matrix.shape[0]:
        G_pre = nx.from_numpy_array(conn_matrix)
        if prune == 1:
            G, pruned_nodes = prune_disconnected(G_pre)
        else:
            G, pruned_nodes = most_important(G_pre)
        print('(Display)')
        if len(pruned_nodes) > 0:
            # Pop highest positions first so earlier pops don't shift later ones.
            for j in sorted(pruned_nodes, reverse=True):
                label_names.pop(j)
                coords.pop(j)
            conn_matrix = nx.to_numpy_array(G)
        else:
            print('No nodes to prune for plot...')

    coords = [tuple(x) for x in coords]

    # Plot connectogram
    if len(conn_matrix) > 20:
        try:
            plotting.plot_connectogram(conn_matrix, conn_model, atlas_select,
                                       dir_path, ID, network, label_names)
        except Exception:
            # Best-effort plot; Exception (not a bare except) lets
            # KeyboardInterrupt/SystemExit through.
            print('\n\n\nWarning: Connectogram plotting failed!')
    else:
        print(
            'Warning: Cannot plot connectogram for graphs smaller than 20 x 20!'
        )

    # Plot adj. matrix based on determined inputs
    if not node_size or node_size == 'None':
        node_size = 'parc'
    plotting.plot_conn_mat_func(conn_matrix, conn_model, atlas_select,
                                dir_path, ID, network, label_names, mask, thr,
                                node_size, smooth)

    # File-name fragments shared by the figure and pickle paths.
    mask_base = os.path.basename(mask).split('.')[0] if mask else None
    mask_part = '_' + str(mask_base) if mask else ''
    network_part = '_' + str(network) + '_' if network else '_'
    size_part = str(node_size) + ('mm_' if node_size != 'parc' else '_')
    if smooth and float(smooth) > 0:
        smooth_part = str(smooth) + 'fwhm_'
    else:
        smooth_part = 'nosm_'
    out_path_fig = ''.join(str(part) for part in (
        dir_path, '/', ID, '_', atlas_select, '_', conn_model, mask_part,
        network_part, thr, '_', size_part, smooth_part, 'func_glass_viz.png'))

    # Save coords and labels to pickles (mask-specific names when masked)
    pkl_tag = '%s_' % (mask_base,) if mask else ''
    coord_path = '%s/coords_%splotting.pkl' % (dir_path, pkl_tag)
    with open(coord_path, 'wb') as f:
        pickle.dump(coords, f, protocol=2)
    labels_path = '%s/labelnames_%splotting.pkl' % (dir_path, pkl_tag)
    with open(labels_path, 'wb') as f:
        pickle.dump(label_names, f, protocol=2)

    # Plot connectome over the ch2better template shipped with pynets. An
    # empty 1x1 connectome creates the display; the real graph is added after.
    ch2better_loc = pkg_resources.resource_filename(
        "pynets", "templates/ch2better.nii.gz")
    connectome = niplot.plot_connectome(np.zeros(shape=(1, 1)), [(0, 0, 0)],
                                        node_size=0.0001,
                                        black_bg=True)
    try:
        connectome.add_overlay(ch2better_loc, alpha=0.35, cmap=plt.cm.gray)
        conn_matrix = np.array(thresholding.autofix(conn_matrix))
        # Symmetric colour range around zero
        z_max = float(np.abs(conn_matrix).max())
        if node_size == 'parc':
            node_size_plot = 2
            if uatlas_select:
                connectome.add_contours(uatlas_select,
                                        filled=False,
                                        alpha=0.3,
                                        colors='black')
        else:
            node_size_plot = int(node_size)
        if len(coords) != conn_matrix.shape[0]:
            raise RuntimeWarning(
                'WARNING: Number of coordinates does not match conn_matrix dimensions. If you are using disparity filtering, try relaxing the α threshold.'
            )
        connectome.add_graph(conn_matrix,
                             coords,
                             edge_threshold=edge_threshold,
                             edge_cmap='Blues',
                             edge_vmax=z_max,
                             edge_vmin=-z_max,
                             node_size=node_size_plot,
                             node_color='auto')
        connectome.savefig(out_path_fig, dpi=dpi_resolution)
    finally:
        # Release the matplotlib figure even when the size check raises.
        connectome.close()
    return
コード例 #5
0
def run_struct_mapping(FSLDIR, ID, bedpostx_dir, dir_path, NETWORK, coords_MNI,
                       node_size, atlas_select, atlas_name, label_names,
                       plot_switch):
    """Estimate a structural connectivity matrix with FSL FLIRT + probtrackx2.

    Steps: register MNI to diffusion space, grow a spherical ROI around each
    MNI coordinate and map it into diffusion space, run probtrackx2 in network
    mode over those seeds, then divide the network matrix by ``waytotal``,
    normalize it, save it, and optionally plot it.

    Parameters
    ----------
    FSLDIR : str
        FSL installation directory (MNI template location).
    ID : str
        Subject identifier.
    bedpostx_dir : str
        BEDPOSTX output directory (samples, brain mask, ``xfms/``).
    dir_path : str
        Output directory for matrices and figures.
    NETWORK : str
        Network name used in file names.
    coords_MNI : sequence of (x, y, z)
        ROI centres in MNI mm coordinates.
    node_size : int
        Size passed to ``fslmaths -kernel sphere``.
    atlas_select : str
        Atlas name (adjacency-matrix title).
    atlas_name : str
        Atlas name (connectogram).
    label_names : list
        Node labels for the connectogram.
    plot_switch : bool
        Plot only when this equals ``True``.

    Returns
    -------
    str
        Path of the symmetrized estimate
        ``<dir_path>/<ID>_<NETWORK>_structural_est.txt``. If probtrackx2
        produced no ``fdt_network_matrix`` the path is returned without being
        written, so callers can detect the failure with ``os.path.isfile``.
    """
    edge_threshold = 0.90
    connectome_fdt_thresh = 1000
    # Previously referenced but never defined (NameError when plotting);
    # 'struct' matches the model label structural_plotting uses.
    conn_model = 'struct'
    do_plot = plot_switch == True  # noqa: E712 - strict comparison kept

    # Auto-set inputs
    nodif_brain_mask_path = bedpostx_dir + '/nodif_brain_mask.nii.gz'
    merged_th_samples_path = bedpostx_dir + '/merged_th1samples.nii.gz'
    merged_f_samples_path = bedpostx_dir + '/merged_f1samples.nii.gz'
    merged_ph_samples_path = bedpostx_dir + '/merged_ph1samples.nii.gz'
    input_MNI = FSLDIR + '/data/standard/MNI152_T1_2mm_brain.nii.gz'
    probtrackx_output_dir_path = bedpostx_dir + '/probtrackx_' + NETWORK
    mni2diff_mat = bedpostx_dir + '/xfms/MNI2diff.mat'
    diff2mni_mat = bedpostx_dir + '/xfms/diff2MNI.mat'
    mx_path = dir_path + '/' + str(ID) + '_' + NETWORK + '_structural_mx.txt'
    # NETWORK is no longer suffixed in place before this is built, so the name
    # no longer depends on plot_switch (it used to become "..._structural_structural").
    est_path = dir_path + '/' + str(ID) + '_' + NETWORK + '_structural_est.txt'

    def _flirt(reference, in_file, **inputs):
        """Run one FSL FLIRT node (mutual-information cost) with extra inputs."""
        node = pe.Node(interface=fsl.FLIRT(cost_func='mutualinfo'),
                       name='coregister')
        node.inputs.reference = reference
        node.inputs.in_file = in_file
        for key, value in inputs.items():
            setattr(node.inputs, key, value)
        node.run()

    # Delete any diffusion-space ROI spheres left from a previous run; a
    # failure on one file no longer stops the others from being removed.
    for stale_sphere in glob.glob(bedpostx_dir + '/roi_sphere*diff.nii.gz'):
        try:
            os.remove(stale_sphere)
        except OSError:
            pass

    # Create the MNI -> diffusion transform matrix, then apply it
    _flirt(merged_f_samples_path, input_MNI, out_matrix_file=mni2diff_mat)
    _flirt(merged_f_samples_path, input_MNI,
           apply_xfm=True,
           in_matrix_file=mni2diff_mat,
           out_file=bedpostx_dir + '/xfms/MNI2diff_affine.nii.gz')

    # Voxel sizes from the diagonal of the brain mask's affine (loaded once)
    mask_affine = masking._load_mask_img(nodif_brain_mask_path)[1]
    x_vox, y_vox, z_vox = np.diagonal(mask_affine[:3, 0:3])

    def mmToVox(mmcoords):
        """Convert an (x, y, z) mm coordinate to a voxel index list."""
        # NOTE(review): 45/63/36 look like the voxel indices of the MNI152 2mm
        # origin (the template used as input_MNI) - confirm.
        return [int(round(int(mmcoords[0]) / x_vox) + 45),
                int(round(int(mmcoords[1]) / y_vox) + 63),
                int(round(int(mmcoords[2]) / z_vox) + 36)]

    # Convert coords back to voxels
    coords = [tuple(mmToVox(coord)) for coord in coords_MNI]

    for j, (X, Y, Z) in enumerate(coords):
        # Single-voxel ROI at the seed coordinate
        out_file1 = bedpostx_dir + '/roi_point_' + str(j) + '.nii.gz'
        args = '-mul 0 -add 1 -roi ' + str(X) + ' 1 ' + str(Y) + ' 1 ' + str(
            Z) + ' 1 0 1'
        maths = fsl.ImageMaths(in_file=input_MNI,
                               op_string=args,
                               out_file=out_file1)
        os.system(maths.cmdline + ' -odt float')

        # Grow it into a sphere with fslmaths' sphere kernel
        out_file2 = bedpostx_dir + '/roi_sphere_' + str(j) + '.nii.gz'
        args = '-kernel sphere ' + str(node_size) + ' -fmean -bin'
        maths = fsl.ImageMaths(in_file=out_file1,
                               op_string=args,
                               out_file=out_file2)
        os.system(maths.cmdline + ' -odt float')

        # Map the sphere from standard (MNI) space into diffusion space
        _flirt(nodif_brain_mask_path, out_file2,
               out_file=out_file2.split('.nii')[0] + '_diff.nii.gz',
               apply_xfm=True,
               in_matrix_file=mni2diff_mat)

    if not os.path.exists(probtrackx_output_dir_path):
        os.makedirs(probtrackx_output_dir_path)

    # List every diffusion-space seed in masks.txt for probtrackx2
    seed_files = glob.glob(bedpostx_dir + '/*diff.nii.gz')
    seeds_text = probtrackx_output_dir_path + '/masks.txt'
    try:
        os.remove(seeds_text)
    except OSError:
        pass
    with open(seeds_text, 'w') as f:
        f.writelines(seed_file + '\n' for seed_file in seed_files)

    # Remove intermediate point/sphere images. The glob class [!diff] rejects
    # names whose last character before '.nii.gz' is d, i or f, so the
    # *_diff.nii.gz seeds are kept.
    for tmp_file in glob.glob(bedpostx_dir + '/roi_point*.nii.gz'):
        os.remove(tmp_file)
    for tmp_file in glob.glob(bedpostx_dir + '/roi_sphere*[!diff].nii.gz'):
        os.remove(tmp_file)

    probtrackx2 = pe.Node(interface=fsl.ProbTrackX2(), name='probtrackx2')
    probtrackx2.inputs.network = True
    probtrackx2.inputs.seed = seeds_text
    probtrackx2.inputs.onewaycondition = True
    probtrackx2.inputs.c_thresh = 0.2
    probtrackx2.inputs.n_steps = 2000
    probtrackx2.inputs.step_length = 0.5
    probtrackx2.inputs.n_samples = 5000
    probtrackx2.inputs.dist_thresh = 0.0
    probtrackx2.inputs.opd = True
    probtrackx2.inputs.loop_check = True
    probtrackx2.inputs.omatrix1 = True
    probtrackx2.overwrite = True
    probtrackx2.inputs.verbose = True
    probtrackx2.inputs.mask = nodif_brain_mask_path
    probtrackx2.inputs.out_dir = probtrackx_output_dir_path
    probtrackx2.inputs.thsamples = merged_th_samples_path
    probtrackx2.inputs.fsamples = merged_f_samples_path
    probtrackx2.inputs.phsamples = merged_ph_samples_path
    probtrackx2.iterables = ("seed", seed_files)
    try:
        probtrackx2.inputs.avoid_mp = vetricular_CSF_mask_path
    except Exception:
        # NOTE(review): vetricular_CSF_mask_path is not defined in this
        # function, so this normally raises NameError and is skipped.
        pass
    probtrackx2.run()

    fdt_matrix_path = probtrackx_output_dir_path + '/fdt_network_matrix'
    if not os.path.exists(fdt_matrix_path):
        # Previously this fell through to an unbound est_path.
        print('No fdt_network_matrix found in ' + probtrackx_output_dir_path)
        return est_path

    mx = np.genfromtxt(fdt_matrix_path)
    waytotal = np.genfromtxt(probtrackx_output_dir_path + '/waytotal')
    # Scope the divide/invalid suppression to this division instead of
    # changing numpy's global error state with np.seterr.
    with np.errstate(divide='ignore', invalid='ignore'):
        conn_matrix = np.divide(mx, waytotal)
    # nan_to_num maps NaN -> 0 (and +/-inf -> large finite values)
    conn_matrix = np.nan_to_num(conn_matrix)
    conn_matrix = normalize(conn_matrix)
    # Symmetrized copy used for the glass-brain plot and the saved estimate.
    # Computed unconditionally: it used to exist only when plotting, so the
    # final np.savetxt failed when plot_switch was False.
    conn_matrix_symm = np.maximum(conn_matrix, conn_matrix.transpose())

    # Save matrix
    np.savetxt(mx_path, conn_matrix, delimiter='\t')

    if do_plot:
        rois_num = conn_matrix.shape[0]
        print("Creating plot of dimensions:\n" + str(rois_num) + ' x ' +
              str(rois_num))
        plt.figure(figsize=(10, 10))
        plt.imshow(conn_matrix,
                   interpolation="nearest",
                   vmax=1,
                   vmin=-1,
                   cmap=plt.cm.RdBu_r)
        plt.colorbar()
        plt.title(atlas_select.upper() + ' ' + NETWORK +
                  ' Structural Connectivity')
        out_path_fig = dir_path + '/' + str(
            ID) + '_' + NETWORK + '_structural_adj_mat.png'
        plt.savefig(out_path_fig)
        plt.close()

    # Bring the tract density image back to MNI space for glass-brain plotting
    fdt_paths_loc = probtrackx_output_dir_path + '/fdt_paths.nii.gz'
    _flirt(input_MNI, nodif_brain_mask_path, out_matrix_file=diff2mni_mat)
    _flirt(input_MNI, nodif_brain_mask_path,
           apply_xfm=True,
           in_matrix_file=diff2mni_mat,
           out_file=bedpostx_dir + '/xfms/diff2MNI_affine.nii.gz')
    _flirt(input_MNI, fdt_paths_loc,
           out_file=fdt_paths_loc.split('.nii')[0] + '_MNI.nii.gz',
           apply_xfm=True,
           in_matrix_file=diff2mni_mat)
    fdt_paths_MNI_loc = probtrackx_output_dir_path + '/fdt_paths_MNI.nii.gz'

    if do_plot:
        clust_colors = colors.to_rgba_array(sns.color_palette("Blues_r", 4))
        connectome = plotting.plot_connectome(
            conn_matrix_symm,
            coords_MNI,
            edge_threshold=edge_threshold,
            node_color=clust_colors,
            edge_cmap=plotting.cm.black_blue_r)
        connectome.add_overlay(img=fdt_paths_MNI_loc,
                               threshold=connectome_fdt_thresh,
                               cmap=plotting.cm.cyan_copper_r)
        out_file_path = dir_path + '/structural_connectome_fig_' + NETWORK + '_' + str(
            ID) + '.png'
        plt.savefig(out_file_path)
        plt.close()

        from pynets import plotting as pynplot
        pynplot.plot_connectogram(conn_matrix, conn_model, atlas_name,
                                  dir_path, ID, NETWORK + '_structural',
                                  label_names)

    try:
        np.savetxt(est_path, conn_matrix_symm, delimiter='\t')
    except RuntimeError:
        print('Diffusion network connectome failed!')
    return est_path
コード例 #6
0
ファイル: plotting.py プロジェクト: kaczmarj/PyNets
def structural_plotting(conn_matrix, conn_matrix_symm, label_names, atlas_select, ID, bedpostx_dir, network, parc, plot_switch, coords):
    """Plot structural connectivity: adjacency matrix, glass brain, connectogram.

    Does nothing unless ``plot_switch == True``. Figures are written to the
    parent directory of ``bedpostx_dir``. The glass-brain view overlays the
    probtrackx ``fdt_paths`` image after FLIRT-registering it to the FSL
    MNI152 1mm template.

    Parameters
    ----------
    conn_matrix : numpy.ndarray
        Structural connectivity matrix (adjacency plot and connectogram).
    conn_matrix_symm : numpy.ndarray
        Symmetrized matrix drawn on the glass brain.
    label_names : list
        Node labels, one per row.
    atlas_select : str
        Atlas name (titles and connectogram).
    ID : str
        Subject identifier.
    bedpostx_dir : str
        BEDPOSTX directory containing the brain mask, ``xfms/`` and
        ``probtrackx_<network>`` output.
    network : str or None
        Network name; falsy means whole-brain (``probtrackx_Whole_brain``).
    parc : object
        Unused; kept for interface compatibility.
    plot_switch : bool
        Plot only when this equals ``True``.
    coords : sequence of (x, y, z)
        Node coordinates in MNI space.

    Raises
    ------
    KeyError
        If the ``FSLDIR`` environment variable is not set.
    """
    if plot_switch != True:  # noqa: E712 - strict comparison kept from the original
        return

    import os
    import nipype.interfaces.fsl as fsl
    import nipype.pipeline.engine as pe
    import matplotlib.pyplot as plt
    import seaborn as sns
    from pynets import plotting as pynplot
    from matplotlib import colors
    from nilearn import plotting as niplot

    edge_threshold = 0.90
    connectome_fdt_thresh = 1000

    # Auto-set inputs. os.environ raises KeyError (not NameError) when unset.
    try:
        FSLDIR = os.environ['FSLDIR']
    except KeyError:
        print('FSLDIR environment variable not set!')
        raise
    nodif_brain_mask_path = bedpostx_dir + '/nodif_brain_mask.nii.gz'
    input_MNI = FSLDIR + '/data/standard/MNI152_T1_1mm_brain.nii.gz'
    # A missing network means whole-brain; use the same label in figure names
    # (concatenating None used to raise TypeError).
    network_label = network if network else 'Whole_brain'
    probtrackx_output_dir_path = bedpostx_dir + '/probtrackx_' + network_label
    dir_path = os.path.dirname(bedpostx_dir)
    diff2mni_mat = bedpostx_dir + '/xfms/diff2MNI.mat'

    # Adjacency matrix
    plt.figure(figsize=(8, 8))
    plt.imshow(conn_matrix, interpolation="nearest", vmax=1, vmin=-1, cmap=plt.cm.RdBu_r)
    plt.xticks(range(len(label_names)), label_names, size='xx-small', rotation=90)
    plt.yticks(range(len(label_names)), label_names, size='xx-small')
    plt_title = atlas_select + ' Structural Connectivity of: ' + str(ID)
    plt.title(plt_title)
    plt.grid(False)
    plt.gcf().subplots_adjust(left=0.8)
    out_path_fig = dir_path + '/structural_adj_mat_' + str(ID) + '.png'
    plt.savefig(out_path_fig)
    plt.close()

    def _flirt(in_file, **inputs):
        """Run one FLIRT node (mutual-information cost) with MNI as reference."""
        node = pe.Node(interface=fsl.FLIRT(cost_func='mutualinfo'), name='coregister')
        node.inputs.reference = input_MNI
        node.inputs.in_file = in_file
        for key, value in inputs.items():
            setattr(node.inputs, key, value)
        node.run()

    # Prepare glass brain figure: diff -> MNI transform, then apply it to the
    # brain mask and to the tract density (fdt_paths) image.
    fdt_paths_loc = probtrackx_output_dir_path + '/fdt_paths.nii.gz'
    _flirt(nodif_brain_mask_path, out_matrix_file=diff2mni_mat)
    _flirt(nodif_brain_mask_path,
           apply_xfm=True,
           in_matrix_file=diff2mni_mat,
           out_file=bedpostx_dir + '/xfms/diff2MNI_affine.nii.gz')
    _flirt(fdt_paths_loc,
           out_file=fdt_paths_loc.split('.nii')[0] + '_MNI.nii.gz',
           apply_xfm=True,
           in_matrix_file=diff2mni_mat)
    fdt_paths_MNI_loc = probtrackx_output_dir_path + '/fdt_paths_MNI.nii.gz'

    clust_colors = colors.to_rgba_array(sns.color_palette("Blues_r", 4))

    # Plotting with glass brain
    connectome = niplot.plot_connectome(conn_matrix_symm, coords, edge_threshold=edge_threshold, node_color=clust_colors, edge_cmap=niplot.cm.black_blue_r)
    connectome.add_overlay(img=fdt_paths_MNI_loc, threshold=connectome_fdt_thresh, cmap=niplot.cm.cyan_copper_r)
    out_file_path = dir_path + '/structural_connectome_fig_' + network_label + '_' + str(ID) + '.png'
    plt.savefig(out_file_path)
    plt.close()

    conn_model = 'struct'
    pynplot.plot_connectogram(conn_matrix, conn_model, atlas_select, dir_path, ID, network_label + '_structural', label_names)
    return
コード例 #7
0
def wb_connectome_with_us_atlas_coords(input_file, ID, atlas_select, NETWORK,
                                       node_size, mask, thr, parlistfile,
                                       all_nets, conn_model, dens_thresh, conf,
                                       adapt_thresh, plot_switch,
                                       bedpostx_dir):
    """Build a whole-brain functional connectome from a parcellation atlas.

    Resolves the parcellation (a nilearn atlas name or ``parlistfile``),
    derives parcel coordinates and labels, optionally masks them, pickles
    them, optionally runs structural mapping, extracts parcel time series,
    estimates and thresholds the connectivity matrix, and optionally plots it.

    Parameters
    ----------
    input_file : str
        Functional NIfTI image.
    ID : str
        Subject identifier.
    atlas_select : str
        Atlas name; names in ``nilearn_atlases`` are fetched via nilearn.
    NETWORK : str or None
        Network name passed to estimation and plotting helpers.
    node_size : int
        Passed to structural mapping.
    mask : str or None
        Mask used to restrict coordinates.
    thr : float or str
        Proportional threshold (also used in pickle file names).
    parlistfile : str or None
        Parcellation image; replaced by the nilearn atlas maps when
        ``atlas_select`` is a nilearn atlas.
    all_nets : object
        When not None, network membership is computed; it is plotted only
        when ``all_nets`` is also not False.
    conn_model : str
        Connectivity model name.
    dens_thresh : float or None
        Target density; selects density thresholding when not None.
    conf : str or None
        Confounds passed to the time-series masker.
    adapt_thresh : bool
        When not False, adaptive thresholding against the structural matrix
        is used (requires ``bedpostx_dir``).
    plot_switch : bool
        Plot only when this equals ``True``.
    bedpostx_dir : str or None
        BEDPOSTX directory; enables structural mapping.

    Returns
    -------
    tuple
        ``(est_path, thr)`` as returned/updated by the estimation step.
    """
    nilearn_atlases = [
        'atlas_aal', 'atlas_craddock_2012', 'atlas_destrieux_2009'
    ]

    # Input is nifti file
    func_file = input_file

    label_names = None
    networks_list = None
    # True when labels come from the nilearn atlas (possibly None) rather than
    # being generated from the parcel count below.
    labels_from_nilearn = atlas_select in nilearn_atlases
    if labels_from_nilearn:
        try:
            # Fetch once and read every attribute from the same result
            atlas = getattr(datasets, 'fetch_%s' % atlas_select)()
            parlistfile = atlas.maps
        except Exception:
            print(
                'PyNets is not ready for multi-scale atlases like BASC just yet!'
            )
            sys.exit()
        label_names = getattr(atlas, 'labels', None)
        networks_list = getattr(atlas, 'networks', None)

    # Fetch user-specified atlas coords
    [coords, atlas_name,
     par_max] = nodemaker.get_names_and_coords_of_parcels(parlistfile)
    atlas_select = atlas_name

    if not labels_from_nilearn:
        # Default labels: 1..N, one per parcel
        label_names = list(range(1, len(coords) + 1))

    # Get subject directory path
    dir_path = os.path.dirname(
        os.path.realpath(func_file)) + '/' + atlas_select
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    # Get coord membership dictionary if all_nets option triggered
    if all_nets is not None:
        [membership,
         membership_plotting] = nodemaker.get_mem_dict(func_file, coords,
                                                       networks_list)

    # Describe user atlas coords
    print('\n' + atlas_name + ' comes with {0} '.format(par_max) + 'parcels' +
          '\n')
    print('\n' + 'Stacked atlas coordinates in array of shape {0}.'.format(
        coords.shape) + '\n')

    # Mask coordinates
    if mask is not None:
        [coords, label_names] = nodemaker.coord_masker(mask, coords,
                                                       label_names)

    # Save coords and label_names to pickles
    coord_path = dir_path + '/coords_wb_' + str(thr) + '.pkl'
    with open(coord_path, 'wb') as f:
        pickle.dump(coords, f)

    labels_path = dir_path + '/labelnames_wb_' + str(thr) + '.pkl'
    with open(labels_path, 'wb') as f:
        pickle.dump(label_names, f)

    # Structural matrix path; stays None without bedpostx_dir (previously the
    # name was simply undefined, causing a NameError under adapt_thresh).
    est_path2 = None
    if bedpostx_dir is not None:
        from pynets.diffconnectometry import run_struct_mapping
        # os.environ raises KeyError (not NameError) when the variable is unset.
        try:
            FSLDIR = os.environ['FSLDIR']
        except KeyError:
            print('FSLDIR environment variable not set!')
            raise
        # NOTE(review): confirm this argument list matches the installed
        # run_struct_mapping signature.
        est_path2 = run_struct_mapping(FSLDIR, ID, bedpostx_dir, dir_path,
                                       NETWORK, coords, node_size)

    # Extract time series from whole brain parcellations
    parcellation = nib.load(parlistfile)
    parcel_masker = input_data.NiftiLabelsMasker(labels_img=parcellation,
                                                 background_label=0,
                                                 memory='nilearn_cache',
                                                 memory_level=5,
                                                 standardize=True)
    ts_within_parcels = parcel_masker.fit_transform(func_file, confounds=conf)
    print('\n' +
          'Time series has {0} samples'.format(ts_within_parcels.shape[0]) +
          '\n')

    # Save time series as txt file
    out_path_ts = dir_path + '/' + ID + '_whole_brain_ts_within_parcels.txt'
    np.savetxt(out_path_ts, ts_within_parcels)

    # Fit connectivity model
    if adapt_thresh is not False:
        if est_path2 is not None and os.path.isfile(est_path2):
            [conn_matrix, est_path, edge_threshold,
             thr] = thresholding.adaptive_thresholding(ts_within_parcels,
                                                       conn_model, NETWORK, ID,
                                                       est_path2, dir_path)
        else:
            print('No structural mx found! Exiting...')
            sys.exit(0)
    elif dens_thresh is None:
        edge_threshold = str(float(thr) * 100) + '%'
        [conn_matrix,
         est_path] = graphestimation.get_conn_matrix(ts_within_parcels,
                                                     conn_model, NETWORK, ID,
                                                     dir_path, thr)
        conn_matrix = thresholding.threshold_proportional(
            conn_matrix, float(thr), dir_path)
        conn_matrix = thresholding.normalize(conn_matrix)
    else:
        [conn_matrix, est_path, edge_threshold,
         thr] = thresholding.density_thresholding(ts_within_parcels,
                                                  conn_model, NETWORK, ID,
                                                  dens_thresh, dir_path)

    if plot_switch == True:  # noqa: E712 - strict comparison kept
        # Plot connectogram
        plotting.plot_connectogram(conn_matrix, conn_model, atlas_name,
                                   dir_path, ID, NETWORK, label_names)

        # Plot adj. matrix based on determined inputs
        atlas_graph_title = plotting.plot_conn_mat(conn_matrix, conn_model,
                                                   atlas_name, dir_path, ID,
                                                   NETWORK, label_names, mask)

        # Plot connectome viz for all Yeo networks. membership_plotting only
        # exists when all_nets is not None, so None must take the else branch.
        if all_nets is not None and all_nets != False:  # noqa: E712
            plotting.plot_membership(membership_plotting, conn_matrix,
                                     conn_model, coords, edge_threshold,
                                     atlas_name, dir_path)
        else:
            out_path_fig = dir_path + '/' + ID + '_connectome_viz.png'
            niplot.plot_connectome(conn_matrix,
                                   coords,
                                   title=atlas_graph_title,
                                   edge_threshold=edge_threshold,
                                   node_size=20,
                                   colorbar=True,
                                   output_file=out_path_fig)
    return est_path, thr
コード例 #8
0
def network_connectome(input_file, ID, atlas_select, NETWORK, node_size, mask,
                       thr, parlistfile, all_nets, conn_model, dens_thresh,
                       conf, adapt_thresh, plot_switch, bedpostx_dir):
    """Estimate, threshold and optionally plot a connectome for one network.

    Node coordinates for ``NETWORK`` come from one of two sources:

    * a nilearn coordinate atlas (``nodemaker.fetch_nilearn_atlas_coords``),
      used when ``parlistfile`` is None and ``atlas_select`` is not one of
      the nilearn parcellation atlases. Only the 'Power 2011 atlas' and the
      'Dosenbach 2010 atlas' are handled; any other atlas name exits.
    * a parcellation image: the user's ``parlistfile`` or, when
      ``atlas_select`` is a nilearn parcellation atlas, that atlas' ``maps``.

    The selected coordinates are optionally masked and pickled into a
    per-atlas directory next to ``input_file``. They are then used as sphere
    seeds for time-series extraction. A connectivity matrix is estimated
    from those time series and thresholded adaptively, proportionally, or
    by density.

    Parameters
    ----------
    input_file
        Path to the functional nifti file.
    ID
        Subject identifier, used in output file names.
    atlas_select
        Atlas name (nilearn parcellation atlas or coordinate atlas).
    NETWORK
        Network of interest. For the Dosenbach atlas the aliases 'DMN',
        'FPTC' and 'CON' are translated to that atlas' network names.
    node_size
        Sphere radius, passed to ``NiftiSpheresMasker`` as ``float``.
    mask
        Optional mask; coordinates outside it are removed.
    thr
        Proportional threshold (also embedded in output file names).
    parlistfile
        Optional user-supplied parcellation image.
    all_nets
        Accepted for interface compatibility; not used by this function.
    conn_model
        Connectivity model forwarded to the estimation routines.
    dens_thresh
        When not None (and ``adapt_thresh`` is False), density thresholding
        replaces proportional thresholding.
    conf
        Confounds forwarded to ``NiftiSpheresMasker.fit_transform``.
    adapt_thresh
        When not False, adaptive thresholding is used. It requires the
        structural matrix produced via ``bedpostx_dir``.
    plot_switch
        When True, connectogram, adjacency-matrix, time-series and
        glass-brain plots are written.
    bedpostx_dir
        When not None, structural mapping is run (requires ``$FSLDIR``).

    Returns
    -------
    est_path, thr
        ``est_path`` as returned by the estimation/thresholding routine, and
        the threshold in effect. Adaptive and density thresholding may
        replace the input ``thr``.

    Raises
    ------
    KeyError
        If ``bedpostx_dir`` is given but ``$FSLDIR`` is not set.
    """
    nilearn_atlases = [
        'atlas_aal', 'atlas_craddock_2012', 'atlas_destrieux_2009'
    ]

    ##Input is nifti file
    func_file = input_file

    # Names that only some code paths assign get explicit defaults. A
    # missing value is then handled deliberately instead of raising NameError.
    label_names = None
    networks_list = None
    est_path2 = None

    ##Test if atlas_select is a nilearn atlas
    if atlas_select in nilearn_atlases:
        atlas = getattr(datasets, 'fetch_%s' % atlas_select)()
        try:
            parlistfile = atlas.maps
            label_names = getattr(atlas, 'labels', None)
            networks_list = getattr(atlas, 'networks', None)
        except RuntimeError:
            print('Error, atlas fetching failed.')
            sys.exit()

    if parlistfile is None and atlas_select not in nilearn_atlases:
        ##Fetch nilearn atlas coords
        [coords, atlas_name, networks_list,
         label_names] = nodemaker.fetch_nilearn_atlas_coords(atlas_select)
        atlas_dir_name = atlas_select

        if atlas_name == 'Power 2011 atlas':
            ##Reference RSN list
            import pkgutil
            import io
            network_coords_ref = NETWORK + '_coords.csv'
            atlas_coords = pkgutil.get_data("pynets",
                                            "rsnrefs/" + network_coords_ref)
            # Keep the first four columns; columns 1-3 are read as x, y, z.
            # Positional .iloc replaces DataFrame.ix, which pandas removed.
            df = pd.read_csv(io.BytesIO(atlas_coords)).iloc[:, 0:4]
            net_coords = [tuple(int(value) for value in row)
                          for row in df.iloc[:, 1:4].values]
            # Node labels are the row positions in the reference list.
            label_names = list(range(len(df)))
        elif atlas_name == 'Dosenbach 2010 atlas':
            coords = list(tuple(x) for x in coords)

            ##Get coord membership dictionary
            [membership, membership_plotting
             ] = nodemaker.get_mem_dict(func_file, coords, networks_list)

            ##Convert to membership dataframe
            mem_df = membership.to_frame().reset_index()
            nets_avail = set(mem_df['index'])

            ##Get network name equivalents
            network_aliases = {
                'DMN': 'default',
                'FPTC': 'fronto-parietal',
                'CON': 'cingulo-opercular',
            }
            if NETWORK in network_aliases:
                NETWORK = network_aliases[NETWORK]
            elif NETWORK not in nets_avail:
                print('Error: ' + NETWORK + ' not available with this atlas!')
                sys.exit()

            ##Get coords for network-of-interest
            net_rows = mem_df.loc[mem_df['index'] == NETWORK]
            net_coords = list(tuple(x) for x in net_rows[[0]].values[:, 0])
            ix_labels = net_rows.index.values

            # Sub-list label names to the network's nodes, unless they
            # already are exactly those node indices. Both sides become
            # plain lists first: comparing a list with a numpy array is
            # elementwise (and can raise), not a single boolean.
            if label_names is None:
                label_names = ix_labels
            else:
                if hasattr(label_names, 'tolist'):
                    names = label_names.tolist()
                else:
                    names = list(label_names)
                if names != ix_labels.tolist():
                    label_names = [names[i] for i in ix_labels]
        else:
            # No network reference list is available for this atlas, so
            # there are no network coordinates to work with.
            print('Error: ' + str(atlas_name) +
                  ' does not support network-specific analysis!')
            sys.exit()
    else:
        ##Fetch user-specified atlas coords
        [coords_all, atlas_name,
         par_max] = nodemaker.get_names_and_coords_of_parcels(parlistfile)
        coords = list(tuple(x) for x in coords_all)
        atlas_dir_name = atlas_name

        ##Get coord membership dictionary
        [membership,
         membership_plotting] = nodemaker.get_mem_dict(func_file, coords,
                                                       networks_list)

        ##Convert to membership dataframe
        mem_df = membership.to_frame().reset_index()

        ##Get coords for network-of-interest
        net_rows = mem_df.loc[mem_df['index'] == NETWORK]
        net_coords = list(tuple(x) for x in net_rows[[0]].values[:, 0])
        ix_labels = net_rows.index.values

        # Fall back to node indices when no usable label list is available.
        if label_names is None:
            label_names = ix_labels
        else:
            try:
                label_names = [label_names[i] for i in ix_labels]
            except (IndexError, KeyError, TypeError):
                label_names = ix_labels

    ##Get subject directory path
    dir_path = os.path.dirname(
        os.path.realpath(func_file)) + '/' + atlas_dir_name
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    ##If masking, remove those coords that fall outside of the mask
    if mask is not None:
        [net_coords,
         label_names] = nodemaker.coord_masker(mask, net_coords, label_names)

    ##Save coords and label_names to pickles
    coord_path = dir_path + '/coords_' + NETWORK + '_' + str(thr) + '.pkl'
    with open(coord_path, 'wb') as f:
        pickle.dump(net_coords, f)

    labels_path = dir_path + '/labelnames_' + NETWORK + '_' + str(
        thr) + '.pkl'
    with open(labels_path, 'wb') as f:
        pickle.dump(label_names, f)

    if bedpostx_dir is not None:
        from pynets.diffconnectometry import run_struct_mapping
        # os.environ[...] raises KeyError (not NameError) when unset.
        try:
            FSLDIR = os.environ['FSLDIR']
        except KeyError:
            print('FSLDIR environment variable not set!')
            raise
        est_path2 = run_struct_mapping(FSLDIR, ID, bedpostx_dir, dir_path,
                                       NETWORK, net_coords, node_size)

    ##Grow ROIs
    masker = input_data.NiftiSpheresMasker(seeds=net_coords,
                                           radius=float(node_size),
                                           allow_overlap=True,
                                           memory_level=5,
                                           memory='nilearn_cache',
                                           verbose=2,
                                           standardize=True)
    ts_within_spheres = masker.fit_transform(func_file, confounds=conf)
    net_ts = ts_within_spheres

    ##Save time series as txt file
    out_path_ts = dir_path + '/' + ID + '_' + NETWORK + '_net_ts.txt'
    np.savetxt(out_path_ts, net_ts)

    ##Fit connectivity model
    if adapt_thresh is not False:
        # The structural matrix only exists if bedpostx mapping ran above.
        if est_path2 is not None and os.path.isfile(est_path2):
            [conn_matrix, est_path, edge_threshold,
             thr] = thresholding.adaptive_thresholding(ts_within_spheres,
                                                       conn_model, NETWORK, ID,
                                                       est_path2, dir_path)
        else:
            print('No structural mx found! Exiting...')
            sys.exit(0)
    elif dens_thresh is None:
        edge_threshold = str(float(thr) * 100) + '%'
        [conn_matrix,
         est_path] = graphestimation.get_conn_matrix(ts_within_spheres,
                                                     conn_model, NETWORK, ID,
                                                     dir_path, thr)
        conn_matrix = thresholding.threshold_proportional(
            conn_matrix, float(thr), dir_path)
        conn_matrix = thresholding.normalize(conn_matrix)
    else:
        [conn_matrix, est_path, edge_threshold,
         thr] = thresholding.density_thresholding(ts_within_spheres,
                                                  conn_model, NETWORK, ID,
                                                  dens_thresh, dir_path)

    if plot_switch == True:
        ##Plot connectogram
        plotting.plot_connectogram(conn_matrix, conn_model, atlas_name,
                                   dir_path, ID, NETWORK, label_names)

        ##Plot adj. matrix based on determined inputs
        plotting.plot_conn_mat(conn_matrix, conn_model, atlas_name, dir_path,
                               ID, NETWORK, label_names, mask)

        ##Plot network time-series
        plotting.plot_timeseries(net_ts, NETWORK, ID, dir_path, atlas_name,
                                 label_names)

        ##Plot connectome viz for specific Yeo networks
        title = "Connectivity Projected on the " + NETWORK
        out_path_fig = dir_path + '/' + ID + '_' + NETWORK + '_connectome_plot.png'
        niplot.plot_connectome(conn_matrix,
                               net_coords,
                               edge_threshold=edge_threshold,
                               title=title,
                               display_mode='lyrz',
                               output_file=out_path_fig)
    return est_path, thr
コード例 #9
0
def wb_connectome_with_nl_atlas_coords(input_file, ID, atlas_select, NETWORK,
                                       node_size, mask, thr, all_nets,
                                       conn_model, dens_thresh, conf,
                                       adapt_thresh, plot_switch,
                                       bedpostx_dir):
    """Estimate, threshold and optionally plot a whole-brain connectome.

    Coordinates and labels come from the nilearn coordinate atlas
    ``atlas_select`` (via ``nodemaker.fetch_nilearn_atlas_coords``). They
    are optionally masked and pickled into ``<input dir>/<atlas_select>``.
    They are then used as sphere seeds for time-series extraction. A
    connectivity matrix is estimated from the time series and thresholded
    adaptively, proportionally, or by density.

    Parameters
    ----------
    input_file
        Path to the functional nifti file.
    ID
        Subject identifier, used in output file names.
    atlas_select
        Name of the nilearn coordinate atlas.
    NETWORK
        Forwarded to structural mapping and the estimation routines.
    node_size
        Sphere radius, passed to ``NiftiSpheresMasker`` as ``float``.
    mask
        Optional mask; coordinates outside it are removed.
    thr
        Proportional threshold (also embedded in pickle file names).
    all_nets
        When not False, network membership is computed and plotted.
    conn_model
        Connectivity model forwarded to the estimation routines.
    dens_thresh
        When not None (and ``adapt_thresh`` is False), density thresholding
        replaces proportional thresholding.
    conf
        Confounds forwarded to ``NiftiSpheresMasker.fit_transform``.
    adapt_thresh
        When not False, adaptive thresholding is used. It requires the
        structural matrix produced via ``bedpostx_dir``.
    plot_switch
        When True, connectogram, adjacency-matrix and connectome plots are
        written.
    bedpostx_dir
        When not None, structural mapping is run (requires ``$FSLDIR``).

    Returns
    -------
    est_path, thr
        ``est_path`` as returned by the estimation/thresholding routine, and
        the threshold in effect. Adaptive and density thresholding may
        replace the input ``thr``.

    Raises
    ------
    KeyError
        If ``bedpostx_dir`` is given but ``$FSLDIR`` is not set.
    """
    ##Input is nifti file
    func_file = input_file

    # Only assigned when structural mapping runs. It is checked explicitly
    # before adaptive thresholding instead of surfacing as a NameError.
    est_path2 = None

    ##Fetch nilearn atlas coords
    [coords, atlas_name, networks_list,
     label_names] = nodemaker.fetch_nilearn_atlas_coords(atlas_select)

    ##Get subject directory path
    dir_path = os.path.dirname(
        os.path.realpath(func_file)) + '/' + atlas_select
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    ##Get coord membership dictionary if all_nets option triggered
    if all_nets != False:
        [membership,
         membership_plotting] = nodemaker.get_mem_dict(func_file, coords,
                                                       networks_list)

    ##Mask coordinates
    if mask is not None:
        [coords, label_names] = nodemaker.coord_masker(mask, coords,
                                                       label_names)

    ##Save coords and label_names to pickles
    coord_path = dir_path + '/coords_wb_' + str(thr) + '.pkl'
    with open(coord_path, 'wb') as f:
        pickle.dump(coords, f)

    labels_path = dir_path + '/labelnames_wb_' + str(thr) + '.pkl'
    with open(labels_path, 'wb') as f:
        pickle.dump(label_names, f)

    if bedpostx_dir is not None:
        from pynets.diffconnectometry import run_struct_mapping
        # os.environ[...] raises KeyError (not NameError) when unset.
        try:
            FSLDIR = os.environ['FSLDIR']
        except KeyError:
            print('FSLDIR environment variable not set!')
            raise
        est_path2 = run_struct_mapping(FSLDIR, ID, bedpostx_dir, dir_path,
                                       NETWORK, coords, node_size)

    ##Extract within-spheres time-series from funct file
    spheres_masker = input_data.NiftiSpheresMasker(seeds=coords,
                                                   radius=float(node_size),
                                                   memory='nilearn_cache',
                                                   memory_level=5,
                                                   verbose=2,
                                                   standardize=True)
    ts_within_spheres = spheres_masker.fit_transform(func_file, confounds=conf)
    print('\n' +
          'Time series has {0} samples'.format(ts_within_spheres.shape[0]) +
          '\n')

    ##Save time series as txt file
    out_path_ts = dir_path + '/' + ID + '_whole_brain_ts_within_spheres.txt'
    np.savetxt(out_path_ts, ts_within_spheres)

    ##Fit connectivity model
    if adapt_thresh is not False:
        # The structural matrix only exists if bedpostx mapping ran above.
        if est_path2 is not None and os.path.isfile(est_path2):
            [conn_matrix, est_path, edge_threshold,
             thr] = thresholding.adaptive_thresholding(ts_within_spheres,
                                                       conn_model, NETWORK, ID,
                                                       est_path2, dir_path)
        else:
            print('No structural mx found! Exiting...')
            sys.exit(0)
    elif dens_thresh is None:
        edge_threshold = str(float(thr) * 100) + '%'
        [conn_matrix,
         est_path] = graphestimation.get_conn_matrix(ts_within_spheres,
                                                     conn_model, NETWORK, ID,
                                                     dir_path, thr)
        conn_matrix = thresholding.threshold_proportional(
            conn_matrix, float(thr), dir_path)
        conn_matrix = thresholding.normalize(conn_matrix)
    else:
        [conn_matrix, est_path, edge_threshold,
         thr] = thresholding.density_thresholding(ts_within_spheres,
                                                  conn_model, NETWORK, ID,
                                                  dens_thresh, dir_path)

    if plot_switch == True:
        ##Plot connectogram
        plotting.plot_connectogram(conn_matrix, conn_model, atlas_name,
                                   dir_path, ID, NETWORK, label_names)

        ##Plot adj. matrix based on determined inputs
        plotting.plot_conn_mat(conn_matrix, conn_model, atlas_name, dir_path,
                               ID, NETWORK, label_names, mask)

        ##Plot connectome viz for all Yeo networks
        if all_nets != False:
            # NOTE(review): membership_plotting was computed from the
            # coordinates before masking, while `coords` may be masked here.
            # Confirm plot_membership tolerates that mismatch.
            plotting.plot_membership(membership_plotting, conn_matrix,
                                     conn_model, coords, edge_threshold,
                                     atlas_name, dir_path)
        else:
            out_path_fig = dir_path + '/' + ID + '_' + atlas_name + '_connectome_viz.png'
            niplot.plot_connectome(conn_matrix,
                                   coords,
                                   title=atlas_name,
                                   edge_threshold=edge_threshold,
                                   node_size=20,
                                   colorbar=True,
                                   output_file=out_path_fig)
    return est_path, thr
コード例 #10
0
def plot_all(conn_matrix, conn_model, atlas_select, dir_path, ID, network,
             label_names, mask, coords, thr, node_size, edge_threshold):
    """Prune a functional graph, then write its plots and plotting pickles.

    The graph built from ``conn_matrix`` is pruned with
    ``pynets.netstats.most_important``. The corresponding entries are
    deleted from ``label_names`` and ``coords``; both lists are modified in
    place. After pruning:

    * a connectogram is plotted when more than 20 nodes remain;
    * the adjacency matrix is plotted via ``plotting.plot_conn_mat_func``;
    * ``coords`` and ``label_names`` are pickled (protocol 2) into
      ``dir_path``;
    * a glass-brain connectome over the bundled ``ch2better`` template is
      saved as
      ``<dir_path>/<ID>_<atlas>_<model>[_<mask>][_<network>]_<thr>_<node_size>_func_glass_viz.png``.

    Parameters
    ----------
    conn_matrix
        Square connectivity matrix (passed to ``nx.from_numpy_matrix``).
    conn_model, atlas_select, ID, network, thr, node_size
        Forwarded to the pynets plotting helpers and used in output names.
    dir_path
        Output directory.
    label_names
        Mutable list of node labels aligned with ``conn_matrix``.
    mask
        Optional mask path. When truthy, its basename up to the first '.'
        is inserted into the output file names.
    coords
        Mutable list of node coordinates aligned with ``conn_matrix``.
    edge_threshold
        Passed to nilearn's ``add_graph``.
    """
    import matplotlib
    matplotlib.use('Agg')
    from matplotlib import pyplot as plt
    from nilearn import plotting as niplot
    import pkg_resources
    import networkx as nx
    from pynets import plotting
    from pynets.netstats import most_important
    try:
        import cPickle as pickle
    except ImportError:
        import _pickle as pickle

    pruning = True
    dpi_resolution = 500

    G_pre = nx.from_numpy_matrix(conn_matrix)
    if pruning:
        [G, pruned_nodes, pruned_edges] = most_important(G_pre)
    else:
        # Nothing pruned, so nothing to remove from the label/coord lists.
        G, pruned_nodes, pruned_edges = G_pre, [], []
    conn_matrix = nx.to_numpy_array(G)

    # Delete by position, highest first, so the remaining positions stay
    # valid. Deleting by position (not via list.index of the value) removes
    # the intended entry even when labels or coordinates repeat. It also
    # avoids ambiguous == comparisons when coordinates are numpy arrays.
    for j in sorted(pruned_nodes, reverse=True):
        del label_names[j]
        del coords[j]

    # NOTE(review): entries of pruned_edges are used as node positions into
    # label_names/coords. Confirm this matches what
    # netstats.most_important returns.
    for j in sorted(pruned_edges, reverse=True):
        del label_names[j]
        del coords[j]

    # Plot connectogram
    if len(conn_matrix) > 20:
        try:
            plotting.plot_connectogram(conn_matrix, conn_model, atlas_select,
                                       dir_path, ID, network, label_names)
        except RuntimeError:
            print('\n\n\nError: Connectogram plotting failed!')
    else:
        print(
            'Error: Cannot plot connectogram for graphs smaller than 20 x 20!')

    # Plot adj. matrix based on determined inputs
    plotting.plot_conn_mat_func(conn_matrix, conn_model, atlas_select,
                                dir_path, ID, network, label_names, mask, thr,
                                node_size)

    # Output-name components; the mask and network parts are optional.
    mask_name = os.path.basename(mask).split('.')[0] if mask else None
    name_parts = [str(ID), str(atlas_select), str(conn_model)]
    if mask:
        name_parts.append(mask_name)
    if network:
        name_parts.append(str(network))
    name_parts.extend([str(thr), str(node_size)])
    out_path_fig = "%s/%s_func_glass_viz.png" % (dir_path,
                                                 '_'.join(name_parts))

    # Save the pruned coords and labels to pickles. The file names carry
    # the mask name only when a mask is given.
    pickle_tag = "%s_" % mask_name if mask else ""
    coord_path = "%s/coords_%splotting.pkl" % (dir_path, pickle_tag)
    with open(coord_path, 'wb') as f:
        pickle.dump(coords, f, protocol=2)
    labels_path = "%s/labelnames_%splotting.pkl" % (dir_path, pickle_tag)
    with open(labels_path, 'wb') as f:
        pickle.dump(label_names, f, protocol=2)

    # Glass-brain connectome over the ch2better template.
    ch2better_loc = pkg_resources.resource_filename(
        "pynets", "templates/ch2better.nii.gz")
    # A single near-invisible node yields a display object. The template
    # and the real graph are then layered onto it.
    connectome = niplot.plot_connectome(np.zeros(shape=(1, 1)), [(0, 0, 0)],
                                        node_size=0.0001)
    connectome.add_overlay(ch2better_loc, alpha=0.4, cmap=plt.cm.gray)
    # Colour range symmetric around zero, set by the largest |weight|.
    z_max = np.abs(conn_matrix).max()
    z_min = -z_max
    connectome.add_graph(conn_matrix,
                         coords,
                         edge_threshold=edge_threshold,
                         edge_cmap='Greens',
                         edge_vmax=z_max,
                         edge_vmin=z_min,
                         node_size=4)
    connectome.savefig(out_path_fig, dpi=dpi_resolution)
    return