Ejemplo n.º 1
0
def draw_colocalization(G,
                        seed_nodes_1,
                        seed_nodes_2,
                        edge_cmap=plt.cm.autumn_r,
                        export_file='colocalization.json',
                        export_network=False,
                        highlight_nodes=None,
                        k=None,
                        largest_connected_component=False,
                        node_cmap=plt.cm.autumn_r,
                        node_size=10,
                        num_nodes=None,
                        physics_enabled=False,
                        Wprime=None,
                        **kwargs):
    '''
    Implements and displays the network colocalization of two sets of seed
    nodes: network propagation is run separately from each seed set and the
    per-node heats are multiplied together. Additional kwargs are passed to
    visJS_module.visjs_network.

    Note: the combined heat is written onto the input graph ``G`` as the
    ``node_heat`` node attribute (the caller's graph is mutated).

    Inputs:
        - G: a networkX graph
        - seed_nodes_1: first set of nodes on which to initialize the simulation
        - seed_nodes_2: second set of nodes on which to initialize the simulation
        - edge_cmap: matplotlib colormap for edges, optional, default: matplotlib.cm.autumn_r
        - export_file: JSON file to export graph data, default: 'colocalization.json'
        - export_network: export network to Cytoscape, default: False
        - highlight_nodes: list of nodes to place borders around, default: None.
                           When given, only seed and highlighted nodes are labelled.
        - k: float, optional, optimal distance between nodes for nx.spring_layout(), default: None
        - largest_connected_component: boolean, optional, whether or not to display largest_connected_component,
                                       default: False
        - node_cmap: matplotlib colormap for nodes, optional, default: matplotlib.cm.autumn_r
        - node_size: size of nodes, default: 10
        - num_nodes: the number of the hottest nodes to graph, default: None (all nodes will be graphed)
        - physics_enabled: enable physics simulation, default: False
        - Wprime:  Normalized adjacency matrix (from normalized_adj_matrix);
                   computed from G when None

    Returns:
        - VisJS html network plot (iframe) of the colocalization, or None when
          a seed node is missing from G or the filtered graph has no nodes or
          no edges (a message is printed in those cases).
    '''

    # check for invalid nodes in seed_nodes; report all of them before bailing
    invalid_nodes = [(node, 'seed_nodes_1') for node in seed_nodes_1
                     if node not in G.nodes()]
    invalid_nodes.extend([(node, 'seed_nodes_2') for node in seed_nodes_2
                          if node not in G.nodes()])
    for node, source in invalid_nodes:
        print('Node {} in {} not in graph'.format(node, source))
    if invalid_nodes:
        return

    # sets for O(1) membership tests in the per-node loops below
    seed_set_1 = set(seed_nodes_1)
    seed_set_2 = set(seed_nodes_2)
    highlight_set = set(highlight_nodes) if highlight_nodes else set()

    # perform the colocalization: product of the two propagation heat vectors
    if Wprime is None:
        Wprime = normalized_adj_matrix(G)
    prop_graph_1 = network_propagation(G, Wprime, seed_nodes_1).to_dict()
    prop_graph_2 = network_propagation(G, Wprime, seed_nodes_2).to_dict()
    prop_graph = {
        node: heat * prop_graph_2[node]
        for node, heat in prop_graph_1.items()
    }
    nx.set_node_attributes(G, name='node_heat', values=prop_graph)

    # find top num_nodes hottest nodes and connected component if requested
    G = set_num_nodes(G, num_nodes)
    if largest_connected_component:
        # nx.connected_component_subgraphs was removed in NetworkX 2.4; take
        # the largest component's node set and copy the induced subgraph.
        largest = max(nx.connected_components(G), key=len, default=None)
        if largest is not None:
            G = G.subgraph(largest).copy()
    nodes = list(G.nodes())
    edges = list(G.edges())

    # check for empty nodes and edges after getting subgraph of G
    if not nodes:
        print('There are no nodes in the graph. Try increasing num_nodes.')
        return
    if not edges:
        print('There are no edges in the graph. Try increasing num_nodes.')
        return

    # set position of each node
    if k is None:
        pos = nx.spring_layout(G)
    else:
        pos = nx.spring_layout(G, k=k)

    nx.set_node_attributes(G,
                           name='xpos',
                           values={n: xy[0] * 1000 for n, xy in pos.items()})
    nx.set_node_attributes(G,
                           name='ypos',
                           values={n: xy[1] * 1000 for n, xy in pos.items()})

    # set the border width of nodes: seeds and highlighted nodes get a border
    kwargs.setdefault('node_border_width', 2)
    outline = kwargs['node_border_width']
    border_width = {
        n: outline if (n in seed_set_1 or n in seed_set_2
                       or n in highlight_set) else 0
        for n in nodes
    }
    nx.set_node_attributes(G, name='nodeOutline', values=border_width)

    # set the shape of each node; seed_nodes_1 takes precedence if a node is
    # in both seed sets
    node_to_shape = {}
    for n in nodes:
        if n in seed_set_1:
            node_to_shape[n] = 'triangle'
        elif n in seed_set_2:
            node_to_shape[n] = 'square'
        else:
            node_to_shape[n] = 'dot'
    nx.set_node_attributes(G, name='nodeShape', values=node_to_shape)

    # add a field for node labels: with highlight_nodes, only seeds and
    # highlighted nodes are labelled; otherwise every node is labelled
    if highlight_nodes:
        node_labels = {
            n: str(n) if (n in seed_set_1 or n in seed_set_2
                          or n in highlight_set) else ''
            for n in nodes
        }
    else:
        node_labels = {n: str(n) for n in nodes}
    nx.set_node_attributes(G, name='nodeLabel', values=node_labels)

    # set the title (hover text) of each node
    node_titles = {
        n: str(n) + '<br/>heat = ' + str(round(data['node_heat'], 10))
        for n, data in G.nodes(data=True)
    }
    nx.set_node_attributes(G, name='nodeTitle', values=node_titles)

    # set the color of each node
    node_to_color = visJS_module.return_node_to_color(
        G,
        field_to_map='node_heat',
        cmap=node_cmap,
        color_vals_transform='log')

    # set heat value of edge based off hottest connecting node's value
    node_attr = nx.get_node_attributes(G, 'node_heat')
    edge_weights = {}
    for u, v in edges:
        heat_u, heat_v = node_attr[u], node_attr[v]
        edge_weights[(u, v)] = heat_u if heat_u > heat_v else heat_v
    nx.set_edge_attributes(G, name='edge_weight', values=edge_weights)

    # set the color of each edge
    edge_to_color = visJS_module.return_edge_to_color(
        G,
        field_to_map='edge_weight',
        cmap=edge_cmap,
        color_vals_transform='log')

    # create the nodes_dict with all relevant fields
    nodes_dict = [{
        'id': str(n),
        'border_width': border_width[n],
        'degree': G.degree(n),
        'color': node_to_color[n],
        'node_label': node_labels[n],
        'node_size': node_size,
        'node_shape': node_to_shape[n],
        'title': node_titles[n],
        'x': np.float64(pos[n][0]).item() * 1000,
        'y': np.float64(pos[n][1]).item() * 1000
    } for n in nodes]

    # map nodes to indices for source/target in edges
    node_map = {n: i for i, n in enumerate(nodes)}

    # create the edges_dict with all relevant fields
    edges_dict = [{
        'source': node_map[u],
        'target': node_map[v],
        'color': edge_to_color[(u, v)]
    } for u, v in edges]

    # set node_size_multiplier to increase node size as graph gets smaller
    if 'node_size_multiplier' not in kwargs:
        if len(nodes) > 500:
            kwargs['node_size_multiplier'] = 1
        elif len(nodes) > 200:
            kwargs['node_size_multiplier'] = 3
        else:
            kwargs['node_size_multiplier'] = 5

    kwargs['physics_enabled'] = physics_enabled

    # if node hovering color not set, set default to black
    kwargs.setdefault('node_color_hover_background', 'black')

    # node size determined by size in nodes_dict, not by id
    kwargs.setdefault('node_size_field', 'node_size')

    # node label determined by value in nodes_dict
    kwargs.setdefault('node_label_field', 'node_label')

    # export the network to JSON for Cytoscape
    if export_network:
        node_colors = map_node_to_color(G, 'node_heat', True)
        nx.set_node_attributes(G, name='nodeColor', values=node_colors)
        edge_colors = map_edge_to_color(G, 'edge_weight', True)
        nx.set_edge_attributes(G, name='edgeColor', values=edge_colors)
        visJS_module.export_to_cytoscape(G=G, export_file=export_file)

    return visJS_module.visjs_network(nodes_dict, edges_dict, **kwargs)
Ejemplo n.º 2
0
def draw_graph_overlap(G1,
                       G2,
                       edge_cmap=plt.cm.coolwarm,
                       export_file='graph_overlap.json',
                       export_network=False,
                       highlight_nodes=None,
                       k=None,
                       node_cmap=plt.cm.autumn,
                       node_name_1='graph 1',
                       node_name_2='graph 2',
                       node_size=10,
                       physics_enabled=False,
                       **kwargs):
    '''
    Takes two networkX graphs and displays their overlap, where intersecting
    nodes are triangles. Additional kwargs are passed to
    visJS_module.visjs_network.

    Inputs:
        - G1: a networkX graph
        - G2: a networkX graph
        - edge_cmap: matplotlib colormap for edges, default: matplotlib.cm.coolwarm
        - export_file: JSON file to export graph data, default: 'graph_overlap.json'
        - export_network: export network to Cytoscape, default: False
        - highlight_nodes: list of nodes to place borders around, default: None.
                           When given, only highlighted nodes are labelled.
        - k: float, optimal distance between nodes for nx.spring_layout(), default: None
        - node_cmap: matplotlib colormap for nodes, default: matplotlib.cm.autumn
        - node_name_1: string to name first graph's nodes, default: 'graph 1'
        - node_name_2: string to name second graph's nodes, default: 'graph 2'
        - node_size: size of nodes, default: 10
        - physics_enabled: enable physics simulation, default: False

    Returns:
        - VisJS html network plot (iframe) of the graph overlap, or None when
          the overlap graph has no nodes (a message is printed).
    '''

    G_overlap = create_graph_overlap(G1, G2, node_name_1, node_name_2)

    # create nodes dict and edges dict for input to visjs
    nodes = list(G_overlap.nodes())
    edges = list(G_overlap.edges())

    # an empty graph would make the position unpacking below fail with an
    # opaque error; bail out the same way draw_colocalization does
    if not nodes:
        print('There are no nodes in the graph.')
        return

    # set the position of each node
    if k is None:
        pos = nx.spring_layout(G_overlap)
    else:
        pos = nx.spring_layout(G_overlap, k=k)

    nx.set_node_attributes(G_overlap,
                           name='xpos',
                           values={n: xy[0] * 1000 for n, xy in pos.items()})
    nx.set_node_attributes(G_overlap,
                           name='ypos',
                           values={n: xy[1] * 1000 for n, xy in pos.items()})

    # set the border width of nodes: only highlighted nodes get a border
    kwargs.setdefault('node_border_width', 2)
    outline = kwargs['node_border_width']
    highlight_set = set(highlight_nodes) if highlight_nodes else set()
    border_width = {n: outline if n in highlight_set else 0 for n in nodes}
    nx.set_node_attributes(G_overlap, name='nodeOutline', values=border_width)

    # set the shape of each node from its 'node_overlap' value. Built as a
    # direct node -> shape mapping so an unexpected overlap value cannot
    # shift shapes onto the wrong nodes; such values fall back to 'dot'.
    overlap_to_shape = {0: 'dot', 1: 'triangle', 2: 'square'}
    node_to_shape = {
        n: overlap_to_shape.get(data['node_overlap'], 'dot')
        for n, data in G_overlap.nodes(data=True)
    }
    nx.set_node_attributes(G_overlap, name='nodeShape', values=node_to_shape)

    # set the node label of each node
    if highlight_nodes:
        node_labels = {n: str(n) if n in highlight_set else '' for n in nodes}
    else:
        node_labels = {n: str(n) for n in nodes}
    nx.set_node_attributes(G_overlap, name='nodeLabel', values=node_labels)

    # set the node title (hover text): membership name, then the node id
    node_titles = {
        n: data['node_name_membership'] + '<br/>' + str(n)
        for n, data in G_overlap.nodes(data=True)
    }
    nx.set_node_attributes(G_overlap, name='nodeTitle', values=node_titles)

    # set color of each node
    node_to_color = visJS_module.return_node_to_color(
        G_overlap,
        field_to_map='node_overlap',
        cmap=node_cmap,
        color_max_frac=.9,
        color_min_frac=.1)

    # set color of each edge
    edge_to_color = visJS_module.return_edge_to_color(
        G_overlap, field_to_map='edge_weight', cmap=edge_cmap, alpha=.3)

    # create the nodes_dict with all relevant fields
    nodes_dict = [{
        'id': str(n),
        'border_width': border_width[n],
        'color': node_to_color[n],
        'degree': G_overlap.degree(n),
        'node_label': node_labels[n],
        'node_shape': node_to_shape[n],
        'node_size': node_size,
        'title': node_titles[n],
        'x': np.float64(pos[n][0]).item() * 1000,
        'y': np.float64(pos[n][1]).item() * 1000
    } for n in nodes]

    # map nodes to indices for source/target in edges
    node_map = {n: i for i, n in enumerate(nodes)}

    # create the edges_dict with all relevant fields
    edges_dict = [{
        'source': node_map[u],
        'target': node_map[v],
        'color': edge_to_color[(u, v)]
    } for u, v in edges]

    # set node_size_multiplier to increase node size as graph gets smaller
    if 'node_size_multiplier' not in kwargs:
        if len(nodes) > 500:
            kwargs['node_size_multiplier'] = 3
        elif len(nodes) > 200:
            kwargs['node_size_multiplier'] = 5
        else:
            kwargs['node_size_multiplier'] = 7

    kwargs['physics_enabled'] = physics_enabled

    # if node hovering color not set, set default to black
    kwargs.setdefault('node_color_hover_background', 'black')

    # node size determined by size in nodes_dict, not by id
    kwargs.setdefault('node_size_field', 'node_size')

    # node label determined by value in nodes_dict
    kwargs.setdefault('node_label_field', 'node_label')

    # export the network to JSON for Cytoscape
    if export_network:
        node_colors = map_node_to_color(G_overlap, 'node_overlap', False)
        nx.set_node_attributes(G_overlap, name='nodeColor', values=node_colors)
        edge_colors = map_edge_to_color(G_overlap, 'edge_weight', False)
        nx.set_edge_attributes(G_overlap, name='edgeColor', values=edge_colors)
        visJS_module.export_to_cytoscape(G=G_overlap, export_file=export_file)

    return visJS_module.visjs_network(nodes_dict, edges_dict, **kwargs)