# NOTE(review): a second ``to_networkx`` definition appears later in this
# module and rebinds the name, so this version is unreachable by name once the
# module is imported. It is kept behaviourally aligned with that definition;
# consider deleting it.
def to_networkx(data, node_attrs=None, edge_attrs=None, to_undirected=False,
                remove_self_loops=False):
    r"""Converts a :class:`torch_geometric.data.Data` instance to a
    :obj:`networkx.Graph` if :attr:`to_undirected` is set to :obj:`True`, or
    a directed :obj:`networkx.DiGraph` otherwise.

    Args:
        data (torch_geometric.data.Data): The data object.
        node_attrs (iterable of str, optional): The node attributes to be
            copied. (default: :obj:`None`)
        edge_attrs (iterable of str, optional): The edge attributes to be
            copied. (default: :obj:`None`)
        to_undirected (bool, optional): If set to :obj:`True`, will return a
            :obj:`networkx.Graph` instead of a :obj:`networkx.DiGraph`. The
            undirected graph will correspond to the upper triangle of the
            corresponding adjacency matrix (edges with ``u <= v``).
            (default: :obj:`False`)
        remove_self_loops (bool, optional): If set to :obj:`True`, will not
            include self loops in the resulting graph. (default: :obj:`False`)

    Returns:
        networkx.Graph or networkx.DiGraph: A graph with nodes
        ``0 .. data.num_nodes - 1``, one edge per kept column of
        ``data.edge_index``, and the requested attributes attached.
    """
    import networkx as nx

    G = nx.Graph() if to_undirected else nx.DiGraph()
    G.add_nodes_from(range(data.num_nodes))

    # Materialise as lists so tuples/generators are accepted, can be
    # concatenated, and can be iterated more than once below.
    node_attrs = list(node_attrs) if node_attrs is not None else []
    edge_attrs = list(edge_attrs) if edge_attrs is not None else []

    values = {}
    for key, value in data(*(node_attrs + edge_attrs)):
        if torch.is_tensor(value):
            # Only drop a trailing singleton feature dimension ([N, 1] -> [N]).
            # A full ``squeeze()`` would also collapse the leading node/edge
            # dimension when there is exactly one node or edge, turning the
            # per-item list into a scalar or a flattened feature vector.
            value = value if value.dim() <= 1 else value.squeeze(-1)
            values[key] = value.tolist()
        else:
            values[key] = value

    for i, (u, v) in enumerate(data.edge_index.t().tolist()):
        # Undirected output keeps only the upper triangle (u <= v), matching
        # the docstring; the mirrored (v, u) entry is skipped.
        if to_undirected and u > v:
            continue
        if remove_self_loops and u == v:
            continue
        G.add_edge(u, v)
        for key in edge_attrs:
            G[u][v][key] = values[key][i]

    for i, feat_dict in G.nodes(data=True):
        for key in node_attrs:
            feat_dict[key] = values[key][i]

    return G


def to_networkx(data, node_attrs=None, edge_attrs=None,
                to_undirected: Union[bool, str] = False,
                remove_self_loops: bool = False):
    r"""Converts a :class:`torch_geometric.data.Data` instance to a
    :obj:`networkx.Graph` if :attr:`to_undirected` is set to :obj:`True`, or
    a directed :obj:`networkx.DiGraph` otherwise.

    Args:
        data (torch_geometric.data.Data): The data object.
        node_attrs (iterable of str, optional): The node attributes to be
            copied. (default: :obj:`None`)
        edge_attrs (iterable of str, optional): The edge attributes to be
            copied. (default: :obj:`None`)
        to_undirected (bool or str, optional): If set to :obj:`True` or
            :obj:`"upper"`, will return a :obj:`networkx.Graph` instead of a
            :obj:`networkx.DiGraph`. The undirected graph will correspond to
            the upper triangle of the corresponding adjacency matrix (edges
            with ``u <= v``). Similarly, if set to :obj:`"lower"`, the
            undirected graph will correspond to the lower triangle of the
            adjacency matrix (edges with ``u >= v``). (default: :obj:`False`)
        remove_self_loops (bool, optional): If set to :obj:`True`, will not
            include self loops in the resulting graph. (default: :obj:`False`)

    Returns:
        networkx.Graph or networkx.DiGraph: A graph with nodes
        ``0 .. data.num_nodes - 1``, one edge per kept column of
        ``data.edge_index``, and the requested attributes attached.

    Raises:
        ValueError: If :attr:`to_undirected` is a string other than
            :obj:`"upper"` or :obj:`"lower"`.
    """
    # Normalise ``to_undirected`` into None (directed), "upper" or "lower".
    # Strings are checked first so that any other truthy flag (e.g. ``1`` or
    # ``numpy.bool_(True)``) behaves exactly like ``True`` instead of building
    # an undirected graph that keeps *both* triangles and silently overwrites
    # edge attributes with the later direction.
    if isinstance(to_undirected, str):
        if to_undirected not in ("upper", "lower"):
            raise ValueError(
                f"'to_undirected' must be a bool, 'upper' or 'lower' "
                f"(got {to_undirected!r})")
        mode = to_undirected
    elif to_undirected:
        mode = "upper"
    else:
        mode = None

    import networkx as nx

    G = nx.DiGraph() if mode is None else nx.Graph()
    G.add_nodes_from(range(data.num_nodes))

    # Materialise as lists so tuples/generators are accepted, can be
    # concatenated, and can be iterated more than once below.
    node_attrs = list(node_attrs) if node_attrs is not None else []
    edge_attrs = list(edge_attrs) if edge_attrs is not None else []

    values = {}
    for key, value in data(*(node_attrs + edge_attrs)):
        if torch.is_tensor(value):
            # Only drop a trailing singleton feature dimension ([N, 1] -> [N]);
            # the leading node/edge dimension must survive even when N == 1.
            value = value if value.dim() <= 1 else value.squeeze(-1)
            values[key] = value.tolist()
        else:
            values[key] = value

    for i, (u, v) in enumerate(data.edge_index.t().tolist()):
        if mode == "upper" and u > v:
            continue
        if mode == "lower" and u < v:
            continue
        if remove_self_loops and u == v:
            continue
        G.add_edge(u, v)
        for key in edge_attrs:
            G[u][v][key] = values[key][i]

    for i, feat_dict in G.nodes(data=True):
        for key in node_attrs:
            feat_dict[key] = values[key][i]

    return G