Example #1
File: interfaces.py Project: dPys/PyNets
    def normalize_graph(self):
        import graspologic.utils as gu

        # By maximum edge weight
        if self.norm == 1:
            self.in_mat = thresholding.normalize(np.nan_to_num(self.in_mat))
        # Apply log10
        elif self.norm == 2:
            self.in_mat = np.log10(np.nan_to_num(self.in_mat))
        # Apply PTR simple-nonzero
        elif self.norm == 3:
            self.in_mat = gu.ptr.pass_to_ranks(np.nan_to_num(self.in_mat),
                                               method="simple-nonzero")
        # Apply PTR simple-all
        elif self.norm == 4:
            self.in_mat = gu.ptr.pass_to_ranks(np.nan_to_num(self.in_mat),
                                               method="simple-all")
        # Apply PTR zero-boost
        elif self.norm == 5:
            self.in_mat = gu.ptr.pass_to_ranks(np.nan_to_num(self.in_mat),
                                               method="zero-boost")
        # Apply standardization [0, 1]
        elif self.norm == 6:
            self.in_mat = thresholding.standardize(np.nan_to_num(self.in_mat))
        elif self.norm == 7:
            # Apply the inverse hyperbolic tangent (i.e. Fisher r-to-z
            # transform) to the matrix, if non-covariance
            self.in_mat = np.arctanh(self.in_mat)
        else:
            pass

        self.in_mat = thresholding.autofix(self.in_mat)
        self.G = nx.from_numpy_array(self.in_mat)

        return self.G
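For reference, the three pass-to-ranks variants dispatched by norm codes 3-5 above can be exercised on their own. A minimal standalone sketch, assuming graspologic is installed and using a toy symmetric matrix:

import numpy as np
import graspologic.utils as gu

# Toy symmetric adjacency matrix (illustrative values only)
A = np.array([[0.0, 0.5, 2.0],
              [0.5, 0.0, 1.0],
              [2.0, 1.0, 0.0]])

# The same three methods used above for norm codes 3, 4, and 5
for method in ("simple-nonzero", "simple-all", "zero-boost"):
    print(method, gu.ptr.pass_to_ranks(np.nan_to_num(A), method=method))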
Example #2
def plot_conn_mat(conn_matrix,
                  labels,
                  out_path_fig,
                  cmap,
                  binarized=False,
                  dpi_resolution=300):
    """
    Plot a connectivity matrix.

    Parameters
    ----------
    conn_matrix : array
        NxN matrix.
    labels : list
        List of string labels corresponding to ROI nodes.
    out_path_fig : str
        File path to save the connectivity matrix image as a .png figure.
    """
    import matplotlib

    matplotlib.use("agg")
    from matplotlib import pyplot as plt
    from nilearn.plotting import plot_matrix
    from pynets.core import thresholding
    import matplotlib.ticker as mticker

    conn_matrix = thresholding.standardize(conn_matrix)
    conn_matrix_bin = thresholding.binarize(conn_matrix)
    conn_matrix_plt = np.nan_to_num(np.multiply(conn_matrix, conn_matrix_bin))

    try:
        plot_matrix(
            conn_matrix_plt,
            figure=(10, 10),
            labels=labels,
            vmax=np.percentile(conn_matrix_plt[conn_matrix_plt > 0], 95),
            vmin=np.min(conn_matrix_plt) - 0.000001,
            reorder="average",
            auto_fit=True,
            grid=False,
            colorbar=False,
            cmap=cmap,
        )
    except RuntimeWarning:
        print("Connectivity matrix too sparse for plotting...")

    if len(labels) > 40:
        tick_interval = int(np.around(len(labels) / 40))
    else:
        tick_interval = int(np.around(len(labels)))
    plt.axes().yaxis.set_major_locator(mticker.MultipleLocator(tick_interval))
    plt.axes().xaxis.set_major_locator(mticker.MultipleLocator(tick_interval))
    plt.savefig(out_path_fig, dpi=dpi_resolution)
    plt.close()
    return
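A hypothetical usage sketch for the plot_conn_mat variant above; the matrix, labels, colormap, and output path are placeholders, and PyNets plus its plotting dependencies (nilearn, matplotlib) are assumed to be installed:

import numpy as np

rng = np.random.default_rng(42)
mat = rng.random((10, 10))
mat = (mat + mat.T) / 2          # symmetrize
np.fill_diagonal(mat, 0)         # drop self-loops
labels = [f"ROI-{i}" for i in range(10)]
plot_conn_mat(mat, labels, "/tmp/conn_mat.png", cmap="viridis")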
Example #3
def plot_conn_mat(conn_matrix, labels, out_path_fig, cmap, dpi_resolution=300):
    """
    Plot a connectivity matrix.

    Parameters
    ----------
    conn_matrix : array
        NxN matrix.
    labels : list
        List of string labels corresponding to ROI nodes.
    out_path_fig : str
        File path to save the connectivity matrix image as a .png figure.
    """
    import matplotlib

    matplotlib.use("agg")
    from matplotlib import pyplot as plt
    from nilearn.plotting import plot_matrix
    from pynets.core import thresholding

    conn_matrix_bin = thresholding.binarize(conn_matrix)
    conn_matrix = thresholding.standardize(conn_matrix)
    conn_matrix_plt = np.nan_to_num(np.multiply(conn_matrix, conn_matrix_bin))

    try:
        plot_matrix(
            conn_matrix_plt,
            figure=(10, 10),
            labels=labels,
            vmax=np.abs(np.max(conn_matrix_plt)),
            vmin=-np.abs(np.max(conn_matrix_plt)),
            reorder="average",
            auto_fit=True,
            grid=False,
            colorbar=False,
            cmap=cmap,
        )
    except RuntimeWarning:
        print("Connectivity matrix too sparse for plotting...")
    plt.savefig(out_path_fig, dpi=dpi_resolution)
    plt.close()
    return
Example #4
def test_standardize(x):
    w = thresholding.standardize(x)
    assert isinstance(w, np.ndarray)
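This test only checks the return type. A hypothetical re-implementation of what thresholding.standardize appears to do, assuming (as the "Apply standardization [0, 1]" comment in Example #1 suggests) a min-max rescaling of edge weights to [0, 1]:

import numpy as np

def standardize_sketch(w):
    # Min-max rescale edge weights to [0, 1] (hypothetical re-implementation)
    w = np.nan_to_num(w)
    return (w - w.min()) / (w.max() - w.min())

x = np.array([[0.0, 2.0, 5.0],
              [2.0, 0.0, 1.0],
              [5.0, 1.0, 0.0]])
w = standardize_sketch(x)
assert isinstance(w, np.ndarray)
assert w.min() == 0.0 and w.max() == 1.0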
Example #5
def compare_motifs(struct_mat, func_mat, name, bins=50, N=4):
    from pynets.stats.netmotifs import adaptivethresh
    from pynets.core.thresholding import standardize
    from scipy import spatial
    import pandas as pd
    from py3plex.core import multinet

    # Structural graph threshold window
    struct_mat = standardize(struct_mat)
    dims_struct = struct_mat.shape[0]
    struct_mat[range(dims_struct), range(dims_struct)] = 0
    tmin_struct = struct_mat.min()
    tmax_struct = struct_mat.max()
    threshes_struct = np.linspace(tmin_struct, tmax_struct, bins)

    # Functional graph threshold window
    func_mat = standardize(func_mat)
    dims_func = func_mat.shape[0]
    func_mat[range(dims_func), range(dims_func)] = 0
    tmin_func = func_mat.min()
    tmax_func = func_mat.max()
    threshes_func = np.linspace(tmin_func, tmax_func, bins)

    assert np.all(
        struct_mat == struct_mat.T), "Structural Matrix must be symmetric"
    assert np.all(
        func_mat == func_mat.T), "Functional Matrix must be symmetric"

    # List of motif classes
    mlib = ['1113', '1122', '1223', '2222', '2233', '3333']

    # Count motifs
    print("%s%s%s%s" % ('Mining ', N, '-node motifs: ', mlib))
    motif_dict = {}
    for thr_struct, thr_func in list(
            itertools.product(threshes_struct, threshes_func)):
        # Count
        at_struct = adaptivethresh(struct_mat, float(thr_struct), mlib, N)
        at_func = adaptivethresh(func_mat, float(thr_func), mlib, N)

        motif_dict["%s%s%s%s" % ('struct_', np.round(
            thr_struct, 4), '_func_', np.round(thr_func, 4))] = {}
        motif_dict["%s%s%s%s" % ('struct_', np.round(thr_struct, 4), '_func_',
                                 np.round(thr_func, 4))]['struct'] = at_struct
        motif_dict["%s%s%s%s" % ('struct_', np.round(
            thr_struct, 4), '_func_', np.round(thr_func, 4))]['func'] = at_func

        print("%s%s%s%s%s" %
              ('Layer 1 (structural) with absolute threshold of : ',
               thr_struct, ' yields ', np.sum(at_struct), ' total motifs'))
        print("%s%s%s%s%s" %
              ('Layer 2 (functional) with absolute threshold of : ', thr_func,
               ' yields ', np.sum(at_func), ' total motifs'))

    for k, v in list(motif_dict.items()):
        if np.sum(v['struct']) == 0 or np.sum(v['func']) == 0:
            del motif_dict[k]

    for k, v in list(motif_dict.items()):
        motif_dict[k]['dist'] = spatial.distance.cosine(v['struct'], v['func'])

    df = pd.DataFrame(motif_dict).T

    df['struct_func_3333'] = np.zeros(len(df))
    df['struct_func_2233'] = np.zeros(len(df))
    df['struct_func_2222'] = np.zeros(len(df))
    df['struct_func_1223'] = np.zeros(len(df))
    df['struct_func_1122'] = np.zeros(len(df))
    df['struct_func_1113'] = np.zeros(len(df))
    df['struct_3333'] = np.zeros(len(df))
    df['func_3333'] = np.zeros(len(df))
    df['struct_2233'] = np.zeros(len(df))
    df['func_2233'] = np.zeros(len(df))
    df['struct_2222'] = np.zeros(len(df))
    df['func_2222'] = np.zeros(len(df))
    df['struct_1223'] = np.zeros(len(df))
    df['func_1223'] = np.zeros(len(df))
    df['struct_1122'] = np.zeros(len(df))
    df['func_1122'] = np.zeros(len(df))
    df['struct_1113'] = np.zeros(len(df))
    df['func_1113'] = np.zeros(len(df))

    for idx in range(len(df)):
        df.set_value(df.index[idx], 'struct_3333', df['struct'][idx][-1])
        df.set_value(df.index[idx], 'func_3333', df['func'][idx][-1])

        df.set_value(df.index[idx], 'struct_2233', df['struct'][idx][-2])
        df.set_value(df.index[idx], 'func_2233', df['func'][idx][-2])

        df.set_value(df.index[idx], 'struct_2222', df['struct'][idx][-3])
        df.set_value(df.index[idx], 'func_2222', df['func'][idx][-3])

        df.set_value(df.index[idx], 'struct_1223', df['struct'][idx][-4])
        df.set_value(df.index[idx], 'func_1223', df['func'][idx][-4])

        df.set_value(df.index[idx], 'struct_1122', df['struct'][idx][-5])
        df.set_value(df.index[idx], 'func_1122', df['func'][idx][-5])

        df.set_value(df.index[idx], 'struct_1113', df['struct'][idx][-6])
        df.set_value(df.index[idx], 'func_1113', df['func'][idx][-6])

    df['struct_func_3333'] = np.abs(df['struct_3333'] - df['func_3333'])
    df['struct_func_2233'] = np.abs(df['struct_2233'] - df['func_2233'])
    df['struct_func_2222'] = np.abs(df['struct_2222'] - df['func_2222'])
    df['struct_func_1223'] = np.abs(df['struct_1223'] - df['func_1223'])
    df['struct_func_1122'] = np.abs(df['struct_1122'] - df['func_1122'])
    df['struct_func_1113'] = np.abs(df['struct_1113'] - df['func_1113'])

    df = df[(df.struct_3333 != 0) & (df.func_3333 != 0) & (df.struct_2233 != 0)
            & (df.func_2233 != 0) & (df.struct_2222 != 0) & (df.func_2222 != 0)
            & (df.struct_1223 != 0) & (df.func_1223 != 0) &
            (df.struct_1122 != 0) & (df.func_1122 != 0) & (df.struct_1113 != 0)
            & (df.func_1113 != 0)]

    df = df.sort_values(by=[
        'dist', 'struct_func_3333', 'struct_func_2233', 'struct_func_2222',
        'struct_func_1223', 'struct_func_1122', 'struct_func_1113',
        'struct_3333', 'func_3333', 'struct_2233', 'func_2233', 'struct_2222',
        'func_2222', 'struct_1223', 'func_1223', 'struct_1122', 'func_1122',
        'struct_1113', 'func_1113'
    ],
                        ascending=[
                            True, True, True, True, True, True, True, False,
                            False, False, False, False, False, False, False,
                            False, False, False, False
                        ])

    # Take the top 25th percentile
    df = df[df['dist'] < df['dist'].quantile(0.25)]
    best_threshes = []
    best_mats = []
    #best_graphs = []
    best_multigraphs = []
    for key in list(df.index):
        func_mat_tmp = func_mat.copy()
        struct_mat_tmp = struct_mat.copy()
        struct_thr = float(key.split('_')[1])
        func_thr = float(key.split('_')[3])
        best_threshes.append((struct_thr, func_thr))

        func_mat_tmp[func_mat_tmp < func_thr] = 0
        struct_mat_tmp[struct_mat_tmp < struct_thr] = 0
        best_mats.append((func_mat_tmp, struct_mat_tmp))

        G = build_nx_multigraph(func_mat, struct_mat, key)
        #best_graphs.append(G)

        B = multinet.multi_layer_network(network_type="multiplex",
                                         directed=False)
        B.add_edges([[x, 1, y, 2, z] for x, y, z in list(G.edges)],
                    input_type="list")
        best_multigraphs.append(B)

    mg_dict = dict(zip(best_threshes, best_multigraphs))

    return mg_dict
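Note that DataFrame.set_value, used in the loops above and again in Example #11, was removed in pandas 1.0. A minimal sketch of the equivalent assignment with .at, on a toy frame rather than the motif data above:

import numpy as np
import pandas as pd

df = pd.DataFrame({"struct": [np.array([1, 2, 3])],
                   "func": [np.array([1, 2, 2])]})
df["struct_3333"] = np.zeros(len(df))
for idx in range(len(df)):
    # pandas >= 1.0 replacement for df.set_value(df.index[idx], col, value)
    df.at[df.index[idx], "struct_3333"] = df["struct"].iloc[idx][-1]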
Example #6
def plot_community_conn_mat(conn_matrix,
                            labels,
                            out_path_fig_comm,
                            community_aff,
                            cmap,
                            dpi_resolution=300):
    """
    Plot a community-parcellated connectivity matrix.

    Parameters
    ----------
    conn_matrix : array
        NxN matrix.
    labels : list
        List of string labels corresponding to ROI nodes.
    out_path_fig_comm : str
        File path to save the community-parcellated connectivity matrix image
        as a .png figure.
    community_aff : array
        Community-affiliation vector.
    """
    import warnings
    warnings.filterwarnings("ignore")
    import matplotlib
    matplotlib.use("agg")
    import mplcyberpunk
    from matplotlib import pyplot as plt
    import matplotlib.patches as patches
    import matplotlib.ticker as mticker
    from nilearn.plotting import plot_matrix
    from pynets.core import thresholding

    plt.style.use("cyberpunk")

    conn_matrix_bin = thresholding.binarize(conn_matrix)
    conn_matrix = thresholding.standardize(conn_matrix)
    conn_matrix_plt = np.nan_to_num(np.multiply(conn_matrix, conn_matrix_bin))

    sorting_array = sorted(range(len(community_aff)),
                           key=lambda k: community_aff[k])
    sorted_conn_matrix = conn_matrix[sorting_array, :]
    sorted_conn_matrix = sorted_conn_matrix[:, sorting_array]
    rois_num = sorted_conn_matrix.shape[0]
    if rois_num < 100:
        try:
            plot_matrix(
                conn_matrix_plt,
                figure=(10, 10),
                labels=labels,
                vmax=np.percentile(conn_matrix_plt[conn_matrix_plt > 0], 95),
                vmin=0,
                reorder=False,
                auto_fit=True,
                grid=False,
                colorbar=False,
                cmap=cmap,
            )
        except RuntimeWarning:
            print("Connectivity matrix too sparse for plotting...")
    else:
        try:
            plot_matrix(
                conn_matrix_plt,
                figure=(10, 10),
                vmax=np.abs(np.max(conn_matrix_plt)),
                vmin=0,
                auto_fit=True,
                grid=False,
                colorbar=False,
                cmap=cmap,
            )
        except RuntimeWarning:
            print("Connectivity matrix too sparse for plotting...")

    ax = plt.gca()
    total_size = 0
    for community in np.unique(community_aff):
        size = sum(sorted(community_aff) == community)
        ax.add_patch(
            patches.Rectangle(
                (total_size, total_size),
                size,
                size,
                fill=False,
                edgecolor="white",
                alpha=None,
                linewidth=1,
            ))
        total_size += size

    if len(labels) > 500:
        tick_interval = 5
    elif len(labels) > 100:
        tick_interval = 4
    elif len(labels) > 50:
        tick_interval = 2
    else:
        tick_interval = 1

    plt.axes().yaxis.set_major_locator(mticker.MultipleLocator(tick_interval))
    plt.axes().xaxis.set_major_locator(mticker.MultipleLocator(tick_interval))
    for param in ['figure.facecolor', 'axes.facecolor', 'savefig.facecolor']:
        plt.rcParams[param] = '#000000'
    plt.savefig(out_path_fig_comm, dpi=dpi_resolution)
    plt.close()
    return
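A standalone sketch of the community ordering used above: sorting node indices by community affiliation is equivalent to a stable argsort on the affiliation vector (toy data for illustration):

import numpy as np

community_aff = np.array([2, 0, 1, 0, 2])
sorting_array = sorted(range(len(community_aff)),
                       key=lambda k: community_aff[k])
assert sorting_array == list(np.argsort(community_aff, kind="stable"))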
Example #7
def plot_conn_mat(conn_matrix,
                  labels,
                  out_path_fig,
                  cmap,
                  binarized=False,
                  dpi_resolution=300):
    """
    Plot a connectivity matrix.

    Parameters
    ----------
    conn_matrix : array
        NxN matrix.
    labels : list
        List of string labels corresponding to ROI nodes.
    out_path_fig : str
        File path to save the connectivity matrix image as a .png figure.
    """
    import warnings
    warnings.filterwarnings("ignore")
    import matplotlib
    matplotlib.use("agg")
    import mplcyberpunk
    from matplotlib import pyplot as plt
    plt.style.use("cyberpunk")
    from nilearn.plotting import plot_matrix
    from pynets.core import thresholding
    import matplotlib.ticker as mticker

    conn_matrix = thresholding.standardize(conn_matrix)
    conn_matrix_bin = thresholding.binarize(conn_matrix)
    conn_matrix_plt = np.nan_to_num(np.multiply(conn_matrix, conn_matrix_bin))

    try:
        plot_matrix(
            conn_matrix_plt,
            figure=(10, 10),
            labels=labels,
            vmax=np.percentile(conn_matrix_plt[conn_matrix_plt > 0], 95),
            vmin=0,
            reorder="average",
            auto_fit=True,
            grid=False,
            colorbar=False,
            cmap=cmap,
        )
    except RuntimeWarning:
        print("Connectivity matrix too sparse for plotting...")

    if len(labels) > 500:
        tick_interval = 5
    elif len(labels) > 100:
        tick_interval = 4
    elif len(labels) > 50:
        tick_interval = 2
    else:
        tick_interval = 1

    plt.axes().yaxis.set_major_locator(mticker.MultipleLocator(tick_interval))
    plt.axes().xaxis.set_major_locator(mticker.MultipleLocator(tick_interval))
    for param in ['figure.facecolor', 'axes.facecolor', 'savefig.facecolor']:
        plt.rcParams[param] = '#000000'
    plt.savefig(out_path_fig, dpi=dpi_resolution)
    plt.close()
    return
Example #8
def plot_community_conn_mat(conn_matrix,
                            labels,
                            out_path_fig_comm,
                            community_aff,
                            cmap,
                            dpi_resolution=300):
    """
    Plot a community-parcellated connectivity matrix.

    Parameters
    ----------
    conn_matrix : array
        NxN matrix.
    labels : list
        List of string labels corresponding to ROI nodes.
    out_path_fig_comm : str
        File path to save the community-parcellated connectivity matrix image
        as a .png figure.
    community_aff : array
        Community-affiliation vector.
    """
    import warnings
    warnings.filterwarnings("ignore")
    import sys
    import pkg_resources
    import yaml
    import matplotlib
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    import matplotlib.ticker as mticker
    matplotlib.use("agg")
    from nilearn.plotting import plot_matrix
    from pynets.core import thresholding

    conn_matrix_bin = thresholding.binarize(conn_matrix)
    conn_matrix = thresholding.standardize(conn_matrix)
    conn_matrix_plt = np.nan_to_num(np.multiply(conn_matrix, conn_matrix_bin))

    try:
        with open(pkg_resources.resource_filename("pynets", "runconfig.yaml"),
                  "r") as stream:
            hardcoded_params = yaml.load(stream, Loader=yaml.FullLoader)
            try:
                labeling_atlas = \
                    hardcoded_params["plotting"]["labeling_atlas"][0]
            except KeyError:
                print(
                    "ERROR: Plotting configuration not successfully extracted"
                    " from runconfig.yaml")
                sys.exit(0)
        stream.close()
        labels = [i[0][labeling_atlas] for i in labels]
    except BaseException:
        pass

    sorting_array = sorted(range(len(community_aff)),
                           key=lambda k: community_aff[k])
    sorted_conn_matrix = conn_matrix[sorting_array, :]
    sorted_conn_matrix = sorted_conn_matrix[:, sorting_array]
    rois_num = sorted_conn_matrix.shape[0]
    if rois_num < 100:
        try:
            plot_matrix(
                conn_matrix_plt,
                figure=(10, 10),
                labels=labels,
                vmax=np.percentile(conn_matrix_plt[conn_matrix_plt > 0], 95),
                vmin=0,
                reorder=False,
                auto_fit=True,
                grid=False,
                colorbar=False,
                cmap=cmap,
            )
        except RuntimeWarning:
            print("Connectivity matrix too sparse for plotting...")
    else:
        try:
            plot_matrix(
                conn_matrix_plt,
                figure=(10, 10),
                vmax=np.abs(np.max(conn_matrix_plt)),
                vmin=0,
                auto_fit=True,
                grid=False,
                colorbar=False,
                cmap=cmap,
            )
        except RuntimeWarning:
            print("Connectivity matrix too sparse for plotting...")

    ax = plt.gca()
    total_size = 0
    for community in np.unique(community_aff):
        size = sum(sorted(community_aff) == community)
        ax.add_patch(
            patches.Rectangle(
                (total_size, total_size),
                size,
                size,
                fill=False,
                edgecolor="black",
                alpha=None,
                linewidth=1,
            ))
        total_size += size

    tick_interval = int(np.around(len(labels))) / 20
    plt.axes().yaxis.set_major_locator(mticker.MultipleLocator(tick_interval))
    plt.axes().xaxis.set_major_locator(mticker.MultipleLocator(tick_interval))
    plt.savefig(out_path_fig_comm, dpi=dpi_resolution)
    plt.close()
    return
Example #9
def plot_conn_mat(conn_matrix,
                  labels,
                  out_path_fig,
                  cmap,
                  binarized=False,
                  dpi_resolution=300):
    """
    Plot a connectivity matrix.

    Parameters
    ----------
    conn_matrix : array
        NxN matrix.
    labels : list
        List of string labels corresponding to ROI nodes.
    out_path_fig : str
        File path to save the connectivity matrix image as a .png figure.
    """
    import warnings
    warnings.filterwarnings("ignore")
    import matplotlib
    matplotlib.use("agg")
    import sys
    import pkg_resources
    import yaml
    from matplotlib import pyplot as plt
    from nilearn.plotting import plot_matrix
    from pynets.core import thresholding
    import matplotlib.ticker as mticker

    conn_matrix = thresholding.standardize(conn_matrix)
    conn_matrix_bin = thresholding.binarize(conn_matrix)
    conn_matrix_plt = np.nan_to_num(np.multiply(conn_matrix, conn_matrix_bin))

    try:
        with open(pkg_resources.resource_filename("pynets", "runconfig.yaml"),
                  "r") as stream:
            hardcoded_params = yaml.load(stream, Loader=yaml.FullLoader)
            try:
                labeling_atlas = \
                    hardcoded_params["plotting"]["labeling_atlas"][0]
            except KeyError:
                print(
                    "ERROR: Plotting configuration not successfully extracted"
                    " from runconfig.yaml")
                sys.exit(0)
        stream.close()
        labels = [i[0][labeling_atlas] for i in labels]
    except BaseException:
        pass

    try:
        plot_matrix(
            conn_matrix_plt,
            figure=(10, 10),
            labels=labels,
            vmax=np.percentile(conn_matrix_plt[conn_matrix_plt > 0], 95),
            vmin=0,
            reorder="average",
            auto_fit=True,
            grid=False,
            colorbar=False,
            cmap=cmap,
        )
    except RuntimeWarning:
        print("Connectivity matrix too sparse for plotting...")

    tick_interval = int(np.around(len(labels))) / 20
    plt.axes().yaxis.set_major_locator(mticker.MultipleLocator(tick_interval))
    plt.axes().xaxis.set_major_locator(mticker.MultipleLocator(tick_interval))
    plt.savefig(out_path_fig, dpi=dpi_resolution)
    plt.close()
    return
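The runconfig.yaml lookup above just takes the first entry of a list under plotting -> labeling_atlas. A standalone sketch with yaml.safe_load and an illustrative in-memory config (the key path matches the example; the value is a placeholder):

import io
import yaml

stream = io.StringIO("plotting:\n  labeling_atlas:\n    - atlas_placeholder\n")
hardcoded_params = yaml.safe_load(stream)
labeling_atlas = hardcoded_params["plotting"]["labeling_atlas"][0]
print(labeling_atlas)  # -> atlas_placeholder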
Example #10
def motif_matching(paths,
                   ID,
                   atlas,
                   namer_dir,
                   name_list,
                   metadata_list,
                   multigraph_list_all,
                   graph_path_list_all,
                   rsn=None):
    import networkx as nx
    import numpy as np
    import glob
    from pathlib import Path
    from pynets.core import thresholding
    from pynets.stats.netmotifs import compare_motifs
    from sklearn.metrics.pairwise import cosine_similarity
    from pynets.stats.netstats import community_resolution_selection
    try:
        import cPickle as pickle
    except ImportError:
        import _pickle as pickle

    [struct_graph_path, func_graph_path] = paths
    struct_mat = np.load(struct_graph_path)
    func_mat = np.load(func_graph_path)

    if rsn is not None:
        struct_coords_path = glob.glob(
            f"{str(Path(struct_graph_path).parent.parent)}/nodes/{rsn}_coords_rsn.pkl"
        )[0]
        func_coords_path = glob.glob(
            f"{str(Path(func_graph_path).parent.parent)}/nodes/{rsn}_coords_rsn.pkl"
        )[0]
        struct_labels_path = glob.glob(
            f"{str(Path(struct_graph_path).parent.parent)}/nodes/{rsn}_labels_rsn.pkl"
        )[0]
        func_labels_path = glob.glob(
            f"{str(Path(func_graph_path).parent.parent)}/nodes/{rsn}_labels_rsn.pkl"
        )[0]
    else:
        struct_coords_path = glob.glob(
            f"{str(Path(struct_graph_path).parent.parent)}/nodes/*coords.pkl"
        )[0]
        func_coords_path = glob.glob(
            f"{str(Path(func_graph_path).parent.parent)}/nodes/*coords.pkl")[0]
        struct_labels_path = glob.glob(
            f"{str(Path(struct_graph_path).parent.parent)}/nodes/*labels.pkl"
        )[0]
        func_labels_path = glob.glob(
            f"{str(Path(func_graph_path).parent.parent)}/nodes/*labels.pkl")[0]

    with open(struct_coords_path, 'rb') as file_:
        struct_coords = pickle.load(file_)
    with open(func_coords_path, 'rb') as file_:
        func_coords = pickle.load(file_)
    with open(struct_labels_path, 'rb') as file_:
        struct_labels = pickle.load(file_)
    with open(func_labels_path, 'rb') as file_:
        func_labels = pickle.load(file_)

    if func_mat.shape == struct_mat.shape:
        func_mat[~struct_mat.astype('bool')] = 0
        struct_mat[~func_mat.astype('bool')] = 0
        print("Number of edge disagreements after matching: ",
              sum(sum(abs(func_mat - struct_mat))))

        metadata = {}
        assert len(struct_coords) == len(struct_labels) == len(
            func_coords) == len(func_labels) == func_mat.shape[0]
        metadata['coords'] = struct_coords
        metadata['labels'] = struct_labels
        metadata_list.append(metadata)

        struct_mat = np.maximum(struct_mat, struct_mat.T)
        func_mat = np.maximum(func_mat, func_mat.T)
        struct_mat = thresholding.standardize(struct_mat)
        func_mat = thresholding.standardize(func_mat)

        struct_node_comm_aff_mat = community_resolution_selection(
            nx.from_numpy_matrix(np.abs(struct_mat)))[1]

        func_node_comm_aff_mat = community_resolution_selection(
            nx.from_numpy_matrix(np.abs(func_mat)))[1]

        struct_comms = []
        for i in np.unique(struct_node_comm_aff_mat):
            struct_comms.append(struct_node_comm_aff_mat == i)

        func_comms = []
        for i in np.unique(func_node_comm_aff_mat):
            func_comms.append(func_node_comm_aff_mat == i)

        sims = cosine_similarity(struct_comms, func_comms)
        struct_comm = struct_comms[np.argmax(sims, axis=0)[0]]
        func_comm = func_comms[np.argmax(sims, axis=0)[0]]

        comm_mask = np.equal.outer(struct_comm, func_comm).astype(bool)
        struct_mat[~comm_mask] = 0
        func_mat[~comm_mask] = 0
        struct_name = struct_graph_path.split('/')[-1].split('_raw.npy')[0]
        func_name = func_graph_path.split('/')[-1].split('_raw.npy')[0]
        name = f"{ID}_{atlas}_mplx_Layer-1_{struct_name}_Layer-2_{func_name}"
        name_list.append(name)
        struct_mat = np.maximum(struct_mat, struct_mat.T)
        func_mat = np.maximum(func_mat, func_mat.T)
        [mldict, g_dict] = compare_motifs(struct_mat, func_mat, name,
                                          namer_dir)
        multigraph_list_all.append(list(mldict.values())[0])
        graph_path_list = []
        for thr in list(g_dict.keys()):
            multigraph_path_list_dict = {}
            [struct, func] = g_dict[thr]
            struct_out = f"{namer_dir}/struct_{atlas}_{struct_name}.npy"
            func_out = f"{namer_dir}/struct_{atlas}_{func_name}_motif-{thr}.npy"
            np.save(struct_out, struct)
            np.save(func_out, func)
            multigraph_path_list_dict[f"struct_{atlas}_{thr}"] = struct_out
            multigraph_path_list_dict[f"func_{atlas}_{thr}"] = func_out
            graph_path_list.append(multigraph_path_list_dict)
        graph_path_list_all.append(graph_path_list)
    else:
        print(
            f"Skipping {rsn} rsn, since structural and functional graphs are not identical shapes."
        )

    return name_list, metadata_list, multigraph_list_all, graph_path_list_all
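A standalone sketch of the co-membership mask built above with np.equal.outer: comm_mask[i, j] is True when node i's structural-community membership equals node j's functional-community membership (toy vectors for illustration):

import numpy as np

struct_comm = np.array([True, True, False])
func_comm = np.array([True, False, False])
comm_mask = np.equal.outer(struct_comm, func_comm).astype(bool)
print(comm_mask)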
Example #11
def compare_motifs(struct_mat, func_mat, name, namer_dir, bins=20, N=4):
    '''
    Compare motif structure and population across structural and functional
    graphs to achieve a homeostatic absolute threshold of each that optimizes
    multiplex community detection and analysis.

    Parameters
    ----------
    struct_mat : ndarray
        M x M structural connectivity matrix.
    func_mat : ndarray
        M x M functional connectivity matrix.
    name : str
        Name used to label the multiplex graph outputs.
    namer_dir : str
        Path to output directory for multiplex data.
    bins : int
        Number of absolute thresholds to evaluate for the functional graph.
    N : int
        Size (number of nodes) of the motifs to count.

    Returns
    -------
    mg_dict : dict
        Dictionary of multiplex graphs keyed by optimal absolute threshold.
    g_dict : dict
        Dictionary of (functional, structural) thresholded matrix pairs keyed
        by optimal absolute threshold.

    References
    ----------
    .. [1] Battiston, F., Nicosia, V., Chavez, M., & Latora, V. (2017).
      Multilayer motif analysis of brain networks. Chaos.
      https://doi.org/10.1063/1.4979282

    '''
    from pynets.stats.netmotifs import adaptivethresh
    from pynets.core.thresholding import threshold_absolute
    from pynets.core.thresholding import standardize
    from scipy import spatial
    from nilearn.connectome import sym_matrix_to_vec
    import pandas as pd
    import gc

    mlib = ['1113', '1122', '1223', '2222', '2233', '3333']

    # Standardize structural graph
    struct_mat = standardize(struct_mat)
    dims_struct = struct_mat.shape[0]
    struct_mat[range(dims_struct), range(dims_struct)] = 0
    at_struct = adaptivethresh(struct_mat, float(0.0), mlib, N)
    print("%s%s%s" %
          ('Layer 1 (structural) has: ', np.sum(at_struct), ' total motifs'))

    # Functional graph threshold window
    func_mat = standardize(func_mat)
    dims_func = func_mat.shape[0]
    func_mat[range(dims_func), range(dims_func)] = 0
    tmin_func = func_mat.min()
    tmax_func = func_mat.max()
    threshes_func = np.linspace(tmin_func, tmax_func, bins)

    assert np.all(
        struct_mat == struct_mat.T), "Structural Matrix must be symmetric"
    assert np.all(
        func_mat == func_mat.T), "Functional Matrix must be symmetric"

    # Count motifs
    print("%s%s%s%s" % ('Mining ', N, '-node motifs: ', mlib))
    motif_dict = {}
    motif_dict['struct'] = {}
    motif_dict['func'] = {}

    mat_dict = {}
    mat_dict['struct'] = sym_matrix_to_vec(struct_mat, discard_diagonal=True)
    mat_dict['funcs'] = {}
    for thr_func in threshes_func:
        # Count
        at_func = adaptivethresh(func_mat, float(thr_func), mlib, N)
        motif_dict['struct']["%s%s" %
                             ('thr-', np.round(thr_func, 4))] = at_struct
        motif_dict['func']["%s%s" % ('thr-', np.round(thr_func, 4))] = at_func
        mat_dict['funcs']["%s%s" %
                          ('thr-', np.round(thr_func, 4))] = sym_matrix_to_vec(
                              threshold_absolute(func_mat, thr_func),
                              discard_diagonal=True)

        print("%s%s%s%s%s" %
              ('Layer 2 (functional) with absolute threshold of: ',
               np.round(thr_func,
                        2), ' yields ', np.sum(at_func), ' total motifs'))
        gc.collect()

    df = pd.DataFrame(motif_dict)

    for idx in range(len(df)):
        df.set_value(
            df.index[idx], 'motif_dist',
            spatial.distance.cosine(df['struct'][idx], df['func'][idx]))

    df = df[pd.notnull(df['motif_dist'])]

    for idx in range(len(df)):
        df.set_value(
            df.index[idx], 'graph_dist_cosine',
            spatial.distance.cosine(
                mat_dict['struct'].reshape(-1, 1),
                mat_dict['funcs'][df.index[idx]].reshape(-1, 1)))
        df.set_value(
            df.index[idx], 'graph_dist_correlation',
            spatial.distance.correlation(
                mat_dict['struct'].reshape(-1, 1),
                mat_dict['funcs'][df.index[idx]].reshape(-1, 1)))

    df['struct_func_3333'] = np.zeros(len(df))
    df['struct_func_2233'] = np.zeros(len(df))
    df['struct_func_2222'] = np.zeros(len(df))
    df['struct_func_1223'] = np.zeros(len(df))
    df['struct_func_1122'] = np.zeros(len(df))
    df['struct_func_1113'] = np.zeros(len(df))
    df['struct_3333'] = np.zeros(len(df))
    df['func_3333'] = np.zeros(len(df))
    df['struct_2233'] = np.zeros(len(df))
    df['func_2233'] = np.zeros(len(df))
    df['struct_2222'] = np.zeros(len(df))
    df['func_2222'] = np.zeros(len(df))
    df['struct_1223'] = np.zeros(len(df))
    df['func_1223'] = np.zeros(len(df))
    df['struct_1122'] = np.zeros(len(df))
    df['func_1122'] = np.zeros(len(df))
    df['struct_1113'] = np.zeros(len(df))
    df['func_1113'] = np.zeros(len(df))

    for idx in range(len(df)):
        df.set_value(df.index[idx], 'struct_3333', df['struct'][idx][-1])
        df.set_value(df.index[idx], 'func_3333', df['func'][idx][-1])

        df.set_value(df.index[idx], 'struct_2233', df['struct'][idx][-2])
        df.set_value(df.index[idx], 'func_2233', df['func'][idx][-2])

        df.set_value(df.index[idx], 'struct_2222', df['struct'][idx][-3])
        df.set_value(df.index[idx], 'func_2222', df['func'][idx][-3])

        df.set_value(df.index[idx], 'struct_1223', df['struct'][idx][-4])
        df.set_value(df.index[idx], 'func_1223', df['func'][idx][-4])

        df.set_value(df.index[idx], 'struct_1122', df['struct'][idx][-5])
        df.set_value(df.index[idx], 'func_1122', df['func'][idx][-5])

        df.set_value(df.index[idx], 'struct_1113', df['struct'][idx][-6])
        df.set_value(df.index[idx], 'func_1113', df['func'][idx][-6])

    df['struct_func_3333'] = np.abs(df['struct_3333'] - df['func_3333'])
    df['struct_func_2233'] = np.abs(df['struct_2233'] - df['func_2233'])
    df['struct_func_2222'] = np.abs(df['struct_2222'] - df['func_2222'])
    df['struct_func_1223'] = np.abs(df['struct_1223'] - df['func_1223'])
    df['struct_func_1122'] = np.abs(df['struct_1122'] - df['func_1122'])
    df['struct_func_1113'] = np.abs(df['struct_1113'] - df['func_1113'])

    df = df.drop(columns=['struct', 'func'])

    df = df.loc[~(df == 0).all(axis=1)]

    df = df.sort_values(by=[
        'motif_dist', 'graph_dist_cosine', 'graph_dist_correlation',
        'struct_func_3333', 'struct_func_2233', 'struct_func_2222',
        'struct_func_1223', 'struct_func_1122', 'struct_func_1113',
        'struct_3333', 'func_3333', 'struct_2233', 'func_2233', 'struct_2222',
        'func_2222', 'struct_1223', 'func_1223', 'struct_1122', 'func_1122',
        'struct_1113', 'func_1113'
    ],
                        ascending=[
                            True, True, False, False, False, False, False,
                            False, False, False, False, False, False, False,
                            False, False, False, False, False, False, False
                        ])

    # Take the top 25th percentile
    df = df.head(int(0.25 * len(df)))
    best_threshes = []
    best_mats = []
    best_multigraphs = []
    for key in list(df.index):
        func_mat_tmp = func_mat.copy()
        struct_mat_tmp = struct_mat.copy()
        struct_thr = float(key.split('-')[-1])
        func_thr = float(key.split('-')[-1])
        best_threshes.append(str(func_thr))

        func_mat_tmp[func_mat_tmp < func_thr] = 0
        struct_mat_tmp[struct_mat_tmp < struct_thr] = 0
        best_mats.append((func_mat_tmp, struct_mat_tmp))

        mG = build_mx_multigraph(func_mat, struct_mat, f"{name}_{key}",
                                 namer_dir)
        best_multigraphs.append(mG)

    mg_dict = dict(zip(best_threshes, best_multigraphs))
    g_dict = dict(zip(best_threshes, best_mats))

    return mg_dict, g_dict
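A quick standalone check of sym_matrix_to_vec as used above, assuming nilearn is installed; with discard_diagonal=True it returns the flattened lower triangle of the symmetric matrix without the diagonal:

import numpy as np
from nilearn.connectome import sym_matrix_to_vec

m = np.array([[0.0, 1.0, 2.0],
              [1.0, 0.0, 3.0],
              [2.0, 3.0, 0.0]])
print(sym_matrix_to_vec(m, discard_diagonal=True))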
Example #12
def test_standardize(x):
    w = thresholding.standardize(x)
    assert w is not None
Example #13
def plot_community_conn_mat(conn_matrix,
                            labels,
                            out_path_fig_comm,
                            community_aff,
                            cmap,
                            dpi_resolution=300):
    """
    Plot a community-parcellated connectivity matrix.

    Parameters
    ----------
    conn_matrix : array
        NxN matrix.
    labels : list
        List of string labels corresponding to ROI nodes.
    out_path_fig_comm : str
        File path to save the community-parcellated connectivity matrix image as a .png figure.
    community_aff : array
        Community-affiliation vector.
    """
    import matplotlib
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    matplotlib.use('agg')
    #from pynets import thresholding
    from nilearn.plotting import plot_matrix
    from pynets.core import thresholding

    conn_matrix_bin = thresholding.binarize(conn_matrix)
    conn_matrix = thresholding.standardize(conn_matrix)
    conn_matrix_plt = np.nan_to_num(np.multiply(conn_matrix, conn_matrix_bin))

    sorting_array = sorted(range(len(community_aff)),
                           key=lambda k: community_aff[k])
    sorted_conn_matrix = conn_matrix[sorting_array, :]
    sorted_conn_matrix = sorted_conn_matrix[:, sorting_array]
    rois_num = sorted_conn_matrix.shape[0]
    if rois_num < 100:
        try:
            plot_matrix(conn_matrix_plt,
                        figure=(10, 10),
                        labels=labels,
                        vmax=np.abs(np.max(conn_matrix_plt)),
                        vmin=-np.abs(np.max(conn_matrix_plt)),
                        reorder=False,
                        auto_fit=True,
                        grid=False,
                        colorbar=False,
                        cmap=cmap)
        except RuntimeWarning:
            print('Connectivity matrix too sparse for plotting...')
    else:
        try:
            plot_matrix(conn_matrix_plt,
                        figure=(10, 10),
                        vmax=np.abs(np.max(conn_matrix_plt)),
                        vmin=-np.abs(np.max(conn_matrix_plt)),
                        auto_fit=True,
                        grid=False,
                        colorbar=False,
                        cmap=cmap)
        except RuntimeWarning:
            print('Connectivity matrix too sparse for plotting...')

    ax = plt.gca()
    total_size = 0
    for community in np.unique(community_aff):
        size = sum(sorted(community_aff) == community)
        ax.add_patch(
            patches.Rectangle((total_size, total_size),
                              size,
                              size,
                              fill=False,
                              edgecolor='black',
                              alpha=None,
                              linewidth=1))
        total_size += size

    plt.savefig(out_path_fig_comm, dpi=dpi_resolution)
    plt.close()
    return
Example #14
def motif_matching(
    paths,
    ID,
    atlas,
    namer_dir,
    name_list,
    metadata_list,
    multigraph_list_all,
    graph_path_list_all,
    rsn=None,
):
    import networkx as nx
    import numpy as np
    import glob
    from pathlib import Path
    import pickle
    from pynets.core import thresholding
    from pynets.stats.netmotifs import compare_motifs
    from sklearn.metrics.pairwise import cosine_similarity
    from pynets.stats.netstats import community_resolution_selection
    from graspy.utils import remove_loops, symmetrize, get_lcc
    from pynets.core.nodemaker import get_brainnetome_node_attributes

    [struct_graph_path, func_graph_path] = paths
    struct_mat = np.load(struct_graph_path)
    func_mat = np.load(func_graph_path)

    [struct_coords, struct_labels, struct_label_intensities] = \
        get_brainnetome_node_attributes(glob.glob(
        f"{str(Path(struct_graph_path).parent.parent)}/nodes/*.json"),
        struct_mat.shape[0])

    [func_coords, func_labels, func_label_intensities] = \
        get_brainnetome_node_attributes(glob.glob(
        f"{str(Path(func_graph_path).parent.parent)}/nodes/*.json"),
        func_mat.shape[0])

    # Find intersecting nodes across modalities (i.e. assuming the same
    # parcellation, but accommodating the possibility of dropped nodes)
    diff1 = list(set(struct_label_intensities) - set(func_label_intensities))
    diff2 = list(set(func_label_intensities) - set(struct_label_intensities))
    G_struct = nx.from_numpy_array(struct_mat)
    G_func = nx.from_numpy_array(func_mat)

    bad_idxs = []
    for val in diff1:
        bad_idxs.append(struct_label_intensities.index(val))
        bad_idxs = sorted(list(set(bad_idxs)), reverse=True)
        if type(struct_coords) is np.ndarray:
            struct_coords = list(tuple(x) for x in struct_coords)
    for j in bad_idxs:
        G_struct.remove_node(j)
        print(f"Removing: {(struct_labels[j], struct_coords[j])}...")
        del struct_labels[j], struct_coords[j]

    bad_idxs = []
    for val in diff2:
        bad_idxs.append(func_label_intensities.index(val))
        bad_idxs = sorted(list(set(bad_idxs)), reverse=True)
        if type(func_coords) is np.ndarray:
            func_coords = list(tuple(x) for x in func_coords)
    for j in bad_idxs:
        G_func.remove_node(j)
        print(f"Removing: {(func_labels[j], func_coords[j])}...")
        del func_labels[j], func_coords[j]

    struct_mat = nx.to_numpy_array(G_struct)
    func_mat = nx.to_numpy_array(G_func)

    struct_mat = thresholding.autofix(symmetrize(remove_loops(struct_mat)))

    func_mat = thresholding.autofix(symmetrize(remove_loops(func_mat)))

    if func_mat.shape == struct_mat.shape:
        func_mat[~struct_mat.astype("bool")] = 0
        struct_mat[~func_mat.astype("bool")] = 0
        print(
            "Edge disagreements after matching: ",
            sum(sum(abs(func_mat - struct_mat))),
        )

        metadata = {}
        assert (
            len(struct_coords)
            == len(struct_labels)
            == len(func_coords)
            == len(func_labels)
            == func_mat.shape[0]
        )
        metadata["coords"] = struct_coords
        metadata["labels"] = struct_labels
        metadata_list.append(metadata)

        struct_mat = np.maximum(struct_mat, struct_mat.T)
        func_mat = np.maximum(func_mat, func_mat.T)
        struct_mat = thresholding.standardize(struct_mat)
        func_mat = thresholding.standardize(func_mat)

        struct_node_comm_aff_mat = community_resolution_selection(
            nx.from_numpy_matrix(np.abs(struct_mat))
        )[1]

        func_node_comm_aff_mat = community_resolution_selection(
            nx.from_numpy_matrix(np.abs(func_mat))
        )[1]

        struct_comms = []
        for i in np.unique(struct_node_comm_aff_mat):
            struct_comms.append(struct_node_comm_aff_mat == i)

        func_comms = []
        for i in np.unique(func_node_comm_aff_mat):
            func_comms.append(func_node_comm_aff_mat == i)

        sims = cosine_similarity(struct_comms, func_comms)
        try:
            struct_comm = struct_comms[np.argmax(sims, axis=0)[0]]
        except BaseException:
            print('Matching by structural communities failed...')
            struct_comm = struct_mat
        try:
            func_comm = func_comms[np.argmax(sims, axis=0)[0]]
        except BaseException:
            print('Matching by functional communities failed...')
            func_comm = func_mat

        comm_mask = np.equal.outer(struct_comm, func_comm).astype(bool)

        try:
            assert comm_mask.shape == struct_mat.shape == func_mat.shape
        except AssertionError as e:
            e.args += (comm_mask, comm_mask.shape, struct_mat,
                       struct_mat.shape, func_mat, func_mat.shape)

        try:
            struct_mat[~comm_mask] = 0
        except BaseException:
            print('Skipping community masking...')
        try:
            func_mat[~comm_mask] = 0
        except BaseException:
            print('Skipping community masking...')

        struct_name = struct_graph_path.split("/rawgraph_"
                                              )[-1].split(".npy")[0]
        func_name = func_graph_path.split("/rawgraph_")[-1].split(".npy")[0]
        name = f"sub-{ID}_{atlas}_mplx_Layer-1_{struct_name}_" \
               f"Layer-2_{func_name}"
        name_list.append(name)
        struct_mat = np.maximum(struct_mat, struct_mat.T)
        func_mat = np.maximum(func_mat, func_mat.T)
        try:
            [mldict, g_dict] = compare_motifs(
                struct_mat, func_mat, name, namer_dir)
        except BaseException:
            print(f"Adaptive thresholding by motif comparisons failed "
                  f"for {name}. This usually happens when no motifs are found")
            return [], [], [], []

        multigraph_list_all.append(list(mldict.values())[0])
        graph_path_list = []
        for thr in list(g_dict.keys()):
            multigraph_path_list_dict = {}
            [struct, func] = g_dict[thr]
            struct_out = f"{namer_dir}/struct_{atlas}_{struct_name}.npy"
            func_out = f"{namer_dir}/struct_{atlas}_{func_name}_" \
                       f"motif-{thr}.npy"
            np.save(struct_out, struct)
            np.save(func_out, func)
            multigraph_path_list_dict[f"struct_{atlas}_{thr}"] = struct_out
            multigraph_path_list_dict[f"func_{atlas}_{thr}"] = func_out
            graph_path_list.append(multigraph_path_list_dict)
        graph_path_list_all.append(graph_path_list)
    else:
        print(
            f"Skipping {rsn} rsn, since structural and functional graphs are "
            f"not identical shapes."
        )

    return name_list, metadata_list, multigraph_list_all, graph_path_list_all
Example #15
def plot_all_struct_func(mG_path, namer_dir, name, modality_paths, metadata):
    """
    Plot adjacency matrix and glass brain for structural-functional multiplex connectome.

    Parameters
    ----------
    mG_path : str
        A gpickle file containing a MultilayerGraph object (see https://github.com/nkoub/multinetx).
    namer_dir : str
        Path to output directory for multiplex data.
    name : str
        Concatenation of multimodal graph filenames.
    modality_paths : tuple
       A tuple of filepath strings to the raw structural and raw functional connectome graph files (.npy).
    metadata : dict
        Dictionary containing coords and labels shared by each layer of the multilayer graph.
    """
    import numpy as np
    import multinetx as mx
    import matplotlib
    matplotlib.use('agg')
    import pkg_resources
    import networkx as nx
    import yaml
    import sys
    from matplotlib import pyplot as plt
    from nilearn import plotting as niplot
    from pynets.core import thresholding
    from pynets.plotting.plot_gen import create_gb_palette

    coords = metadata['coords']
    labels = metadata['labels']

    ch2better_loc = pkg_resources.resource_filename(
        "pynets", "templates/ch2better.nii.gz")

    with open(pkg_resources.resource_filename("pynets", "runconfig.yaml"),
              'r') as stream:
        hardcoded_params = yaml.load(stream, Loader=yaml.FullLoader)
        try:
            color_theme_func = hardcoded_params['plotting']['functional'][
                'glassbrain']['color_theme'][0]
            color_theme_struct = hardcoded_params['plotting']['structural'][
                'glassbrain']['color_theme'][0]
            glassbrain = hardcoded_params['plotting']['glassbrain'][0]
            adjacency = hardcoded_params['plotting']['adjacency'][0]
            dpi_resolution = hardcoded_params['plotting']['dpi'][0]
        except KeyError:
            print(
                'ERROR: Plotting configuration not successfully extracted from runconfig.yaml'
            )
            sys.exit(0)
    stream.close()

    [struct_mat_path, func_mat_path] = modality_paths
    struct_mat, func_mat = [np.load(struct_mat_path), np.load(func_mat_path)]

    if adjacency is True:
        # Multiplex adjacency
        mG = nx.read_gpickle(mG_path)

        fig = plt.figure(figsize=(15, 5))
        ax1 = fig.add_subplot(121)
        adj = thresholding.standardize(
            mx.adjacency_matrix(mG, weight='weight').todense())
        [z_min, z_max] = np.abs(adj).min(), np.abs(adj).max()

        adj[adj == 0] = np.nan

        ax1.imshow(adj,
                   origin='lower',
                   interpolation='nearest',
                   cmap=plt.cm.RdBu,
                   vmin=0.01,
                   vmax=z_max)
        ax1.set_title('Supra-Adjacency Matrix')

        ax2 = fig.add_subplot(122)
        ax2.axis('off')
        ax2.set_title(f"Functional-Structural Multiplex Connectome")

        pos = mx.get_position(mG,
                              mx.fruchterman_reingold_layout(mG.get_layer(0)),
                              layer_vertical_shift=1.0,
                              layer_horizontal_shift=0.0,
                              proj_angle=7)
        edge_intensities = []
        for a, b, w in mG.edges(data=True):
            if w != {}:
                edge_intensities.append(w['weight'])
            else:
                edge_intensities.append(0)

        node_centralities = list(
            nx.algorithms.eigenvector_centrality(mG, weight='weight').values())
        mx.draw_networkx(mG,
                         pos=pos,
                         ax=ax2,
                         node_size=100,
                         with_labels=False,
                         edge_color=edge_intensities,
                         node_color=node_centralities,
                         edge_vmin=z_min,
                         edge_vmax=z_max,
                         dim=3,
                         font_size=6,
                         widths=3,
                         alpha=0.7,
                         cmap=plt.cm.RdBu)
        plt.savefig(f"{namer_dir}/{name[:200]}supra_adj.png",
                    dpi=dpi_resolution)

    if glassbrain is True:
        # Multiplex glass brain
        views = ['x', 'y', 'z']
        connectome = niplot.plot_connectome(np.zeros(shape=(1, 1)),
                                            [(0, 0, 0)],
                                            node_size=0.0001,
                                            black_bg=True)
        connectome.add_overlay(ch2better_loc, alpha=0.50, cmap=plt.cm.gray)

        [struct_mat, _, _, _, edge_sizes_struct, _, _, coords,
         labels] = create_gb_palette(struct_mat,
                                     color_theme_struct,
                                     coords,
                                     labels,
                                     prune=False)

        connectome.add_graph(struct_mat,
                             coords,
                             edge_threshold='50%',
                             edge_cmap=plt.cm.binary,
                             node_size=1,
                             edge_kwargs={
                                 'alpha': 0.50,
                                 "lineStyle": 'dashed'
                             },
                             node_kwargs={'alpha': 0.95},
                             edge_vmax=float(1),
                             edge_vmin=float(1))

        for view in views:
            mod_lines = []
            for line, edge_size in list(
                    zip(connectome.axes[view].ax.lines, edge_sizes_struct)):
                line.set_lw(edge_size)
                mod_lines.append(line)
            connectome.axes[view].ax.lines = mod_lines

        [
            func_mat, clust_pal_edges, clust_pal_nodes, node_sizes,
            edge_sizes_func, z_min, z_max, coords, labels
        ] = create_gb_palette(func_mat,
                              color_theme_func,
                              coords,
                              labels,
                              prune=False)
        connectome.add_graph(func_mat,
                             coords,
                             edge_threshold='50%',
                             edge_cmap=clust_pal_edges,
                             edge_kwargs={'alpha': 0.75},
                             edge_vmax=float(z_max),
                             edge_vmin=float(z_min),
                             node_size=node_sizes,
                             node_color=clust_pal_nodes)

        for view in views:
            mod_lines = []
            for line, edge_size in list(
                    zip(
                        connectome.axes[view].ax.
                        lines[len(edge_sizes_struct):], edge_sizes_func)):
                line.set_lw(edge_size)
                mod_lines.append(line)
            connectome.axes[view].ax.lines[len(edge_sizes_struct):] = mod_lines

        connectome.savefig(f"{namer_dir}/{name[:200]}glassbrain_mplx.png",
                           dpi=dpi_resolution)

    return
Example #16
File: multiplex.py Project: dPys/PyNets
def matching(
    paths,
    atlas,
    namer_dir,
):
    import glob
    from pathlib import Path
    import networkx as nx
    import numpy as np
    from pynets.core import thresholding
    from pynets.statistics.utils import parse_closest_ixs
    from graspologic.utils import remove_loops, symmetrize, \
        multigraph_lcc_intersection

    [dwi_graph_path, func_graph_path] = paths
    dwi_mat = np.load(dwi_graph_path)
    func_mat = np.load(func_graph_path)
    dwi_mat = thresholding.autofix(symmetrize(remove_loops(dwi_mat)))
    func_mat = thresholding.autofix(symmetrize(remove_loops(func_mat)))
    dwi_mat = thresholding.standardize(dwi_mat)
    func_mat = thresholding.standardize(func_mat)

    node_dict_dwi = parse_closest_ixs(
        glob.glob(f"{str(Path(dwi_graph_path).parent.parent)}"
                  f"/nodes/*.json"), dwi_mat.shape[0])[1]

    node_dict_func = parse_closest_ixs(
        glob.glob(f"{str(Path(func_graph_path).parent.parent)}"
                  f"/nodes/*.json"), func_mat.shape[0])[1]

    G_dwi = nx.from_numpy_array(dwi_mat)
    nx.set_edge_attributes(G_dwi, 'structural',
                           nx.get_edge_attributes(G_dwi, 'weight').values())
    nx.set_node_attributes(G_dwi, dict(node_dict_dwi), name='dwi')
    #G_dwi.nodes(data=True)

    G_func = nx.from_numpy_array(func_mat)
    nx.set_edge_attributes(G_func, 'functional',
                           nx.get_edge_attributes(G_func, 'weight').values())
    nx.set_node_attributes(G_func, dict(node_dict_func), name='func')
    #G_func.nodes(data=True)

    R = G_dwi.copy()
    R.remove_nodes_from(n for n in G_dwi if n not in G_func)
    R.remove_edges_from(e for e in G_dwi.edges if e not in G_func.edges)
    G_dwi = R.copy()

    R = G_func.copy()
    R.remove_nodes_from(n for n in G_func if n not in G_dwi)
    R.remove_edges_from(e for e in G_func.edges if e not in G_dwi.edges)
    G_func = R.copy()

    [G_dwi, G_func] = multigraph_lcc_intersection([G_dwi, G_func])

    def writeJSON(metadata_str, outputdir):
        import json
        import uuid
        modality = metadata_str.split('modality-')[1].split('_')[0]
        metadata_list = [
            i for i in metadata_str.split('modality-')[1].split('_')
            if '-' in i
        ]
        hash = str(uuid.uuid4())
        filename = f"{outputdir}/sidecar_modality-{modality}_{hash}.json"
        metadata_dict = {}
        for meta in metadata_list:
            k, v = meta.split('-')
            metadata_dict[k] = v
        with open(filename, 'w+') as jsonfile:
            json.dump(metadata_dict, jsonfile, indent=4)
        jsonfile.close()
        return hash

    dwi_name = dwi_graph_path.split("/")[-1].split(".npy")[0]
    func_name = func_graph_path.split("/")[-1].split(".npy")[0]

    dwi_hash = writeJSON(dwi_name, namer_dir)
    func_hash = writeJSON(func_name, namer_dir)

    name = f"{atlas}_mplx_layer1-dwi_ensemble-{dwi_hash}_" \
           f"layer2-func_ensemble-{func_hash}"

    dwi_opt, func_opt, best_mi = optimize_mutual_info(
        nx.to_numpy_array(G_dwi), nx.to_numpy_array(G_func), bins=50)

    func_mat_final = list(func_opt.values())[0]
    dwi_mat_final = list(dwi_opt.values())[0]
    G_dwi_final = nx.from_numpy_array(dwi_mat_final)
    G_func_final = nx.from_numpy_array(func_mat_final)

    G_multi = nx.OrderedMultiGraph(nx.compose(G_dwi_final, G_func_final))

    out_name = f"{name}_matchthr-{list(dwi_opt.keys())[0]}_" \
               f"{list(func_opt.keys())[0]}"
    mG = build_mx_multigraph(nx.to_numpy_array(G_func_final),
                             nx.to_numpy_array(G_dwi_final), out_name,
                             namer_dir)

    mG_nx = f"{namer_dir}/{out_name}.gpickle"
    nx.write_gpickle(G_multi, mG_nx)

    dwi_file_out = f"{namer_dir}/{dwi_name}.npy"
    func_file_out = f"{namer_dir}/{func_name}.npy"
    np.save(dwi_file_out, dwi_mat_final)
    np.save(func_file_out, func_mat_final)
    return mG_nx, mG, dwi_file_out, func_file_out
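A standalone sketch of the node/edge intersection pattern used above to match the two graphs, assuming only networkx (toy graphs for illustration):

import networkx as nx

G1 = nx.Graph([(0, 1), (1, 2), (2, 3)])
G2 = nx.Graph([(0, 1), (2, 3), (3, 4)])

# Keep only the nodes and edges that both graphs share
R = G1.copy()
R.remove_nodes_from(n for n in G1 if n not in G2)
R.remove_edges_from(e for e in G1.edges if e not in G2.edges)
print(sorted(R.edges))  # [(0, 1), (2, 3)]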